PyTorch

PyTorch

Skorch 将 Scikit-learn API 引入到 PyTorch_。Skorch 允许将 PyTorch 模型封装为与 Scikit-learn 兼容的估计器。这意味着,通过 Skorch 封装的 PyTorch 模型可以与 Dask-ML API 的其他部分一起使用。例如,在使用 Dask-ML 的 HyperbandSearchCVIncremental 时,可以在封装后与 PyTorch 一起使用。

我们鼓励查看 Skorch 文档以获取完整详情。

示例用法

首先,让我们创建一个普通的 PyTorch 模型:

import torch.nn as nn
import torch.nn.functional as F

class ShallowNet(nn.Module):
    def __init__(self, n_features=5):
        super().__init__()
        self.layer1 = nn.Linear(n_features, 1)

    def forward(self, x):
        return F.relu(self.layer1(x))

通过这个,使用 Skorch 变得非常简单:

from skorch import NeuralNetRegressor
import torch.optim as optim

niceties = {
    "callbacks": False,
    "warm_start": False,
    "train_split": None,
    "max_epochs": 1,
}

model = NeuralNetRegressor(
    module=ShallowNet,
    module__n_features=5,
    criterion=nn.MSELoss,
    optimizer=optim.SGD,
    optimizer__lr=0.1,
    optimizer__momentum=0.9,
    batch_size=64,
    **niceties,
)

PyTorch nn.Module 的每个参数都带有 module__ 前缀,优化器(如 optim.SGD 接受 lrmomentum 参数)也是如此。niceties 确保 Skorch 使用所有数据进行训练,并且不会打印过多的日志。

现在,这个模型可以与 Dask-ML 一起使用。例如,可以执行以下操作: