开始使用

开始使用训练操作符

本指南描述了如何开始使用训练操作符并运行一些简单的示例。

前提条件

您需要安装以下组件以运行示例:

开始使用 PyTorchJob

您可以使用Python SDK创建您的第一个训练操作符分布式PyTorchJob。定义实现端到端模型训练的训练函数。每个工作节点将在适当的Kubernetes Pod上执行此函数。通常,这个函数包含下载数据集、创建模型和训练模型的逻辑。

训练操作符将自动为适当的 PyTorchJob 工作者设置 WORLD_SIZERANK 以执行 PyTorch 分布式数据并行 (DDP)

如果您将训练操作符作为Kubeflow平台的一部分进行安装,您可以打开一个新的 Kubeflow Notebook来运行此脚本。如果您 单独安装训练操作符,请确保您 配置本地 kubeconfig 以访问您安装了训练操作符的Kubernetes集群。

def train_func():
    import torch
    import torch.nn.functional as F
    from torch.utils.data import DistributedSampler
    from torchvision import datasets, transforms
    import torch.distributed as dist

    # [1] Setup PyTorch DDP. Distributed environment will be set automatically by Training Operator.
    dist.init_process_group(backend="nccl")
    Distributor = torch.nn.parallel.DistributedDataParallel
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    print(
        "Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}".format(
            dist.get_world_size(),
            dist.get_rank(),
            local_rank,
        )
    )

    # [2] Create PyTorch CNN Model.
    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
            self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
            self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
            self.fc2 = torch.nn.Linear(500, 10)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2, 2)
            x = x.view(-1, 4 * 4 * 50)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    # [3] Attach model to the correct GPU device and distributor.
    device = torch.device(f"cuda:{local_rank}")
    model = Net().to(device)
    model = Distributor(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    # [4] Setup FashionMNIST dataloader and distribute data across PyTorchJob workers.
    dataset = datasets.FashionMNIST(
        "./data",
        download=True,
        train=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=128,
        sampler=DistributedSampler(dataset),
    )

    # [5] Start model Training.
    for epoch in range(3):
        for batch_idx, (data, target) in enumerate(train_loader):
            # Attach Tensors to the device.
            data = data.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 10 == 0 and dist.get_rank() == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )


from kubeflow.training import TrainingClient

# Start PyTorchJob with 3 Workers and 1 GPU per Worker (e.g. multi-node, multi-worker job).
TrainingClient().create_job(
    name="pytorch-ddp",
    train_func=train_func,
    num_procs_per_worker="auto",
    num_workers=3,
    resources_per_worker={"gpu": "1"},
)

开始使用 TFJob

类似于 PyTorchJob 示例,您可以使用 Python SDK 创建您的第一个分布式 TensorFlow 作业。运行以下脚本以使用预先创建的 Docker 镜像创建 TFJob: docker.io/kubeflow/tf-mnist-with-summaries:latest,该镜像包含 分布式 TensorFlow 代码

from kubeflow.training import TrainingClient

TrainingClient().create_job(
    name="tensorflow-dist",
    job_kind="TFJob",
    base_image="docker.io/kubeflow/tf-mnist-with-summaries:latest",
    num_workers=3,
)

运行以下API以获取您的TFJob的日志:

TrainingClient().get_job_logs(
    name="tensorflow-dist",
    job_kind="TFJob",
    follow=True,
)

下一步

反馈

此页面有帮助吗?