PyTorch
内容
PyTorch¶
Skorch 将 Scikit-learn API 引入到 PyTorch_。Skorch 允许将 PyTorch 模型封装为与 Scikit-learn 兼容的估计器。这意味着,通过 Skorch 封装的 PyTorch 模型可以与 Dask-ML API 的其他部分一起使用。例如,在使用 Dask-ML 的 HyperbandSearchCV 或 Incremental 时,可以在封装后与 PyTorch 一起使用。
我们鼓励查看 Skorch 文档以获取完整详情。
示例用法¶
首先,让我们创建一个普通的 PyTorch 模型:
import torch.nn as nn
import torch.nn.functional as F
class ShallowNet(nn.Module):
def __init__(self, n_features=5):
super().__init__()
self.layer1 = nn.Linear(n_features, 1)
def forward(self, x):
return F.relu(self.layer1(x))
通过这个,使用 Skorch 变得非常简单:
from skorch import NeuralNetRegressor
import torch.optim as optim
niceties = {
"callbacks": False,
"warm_start": False,
"train_split": None,
"max_epochs": 1,
}
model = NeuralNetRegressor(
module=ShallowNet,
module__n_features=5,
criterion=nn.MSELoss,
optimizer=optim.SGD,
optimizer__lr=0.1,
optimizer__momentum=0.9,
batch_size=64,
**niceties,
)
PyTorch nn.Module 的每个参数都带有 module__ 前缀,优化器(如 optim.SGD 接受 lr 和 momentum 参数)也是如此。niceties 确保 Skorch 使用所有数据进行训练,并且不会打印过多的日志。
现在,这个模型可以与 Dask-ML 一起使用。例如,可以执行以下操作:
使用 PyTorch 与 Dask-ML 的模型选择功能,包括
HyperbandSearchCV。使用 PyTorch 与 Dask-ML 的
Incremental。