纯PyTorch中的多GPU训练

对于许多大规模的真实世界数据集,可能需要在多个GPU上进行扩展训练。本教程介绍了如何在中通过使用torch.nn.parallel.DistributedDataParallel设置多GPU训练管道,而无需任何其他第三方库(如)。请注意,此方法基于数据并行。这意味着每个GPU运行模型的相同副本;如果您希望在设备之间扩展模型,您可能需要查看PyTorch FSDP。数据并行允许您通过在GPU之间聚合梯度来增加模型的批量大小,然后在每个模型副本中共享相同的优化器步骤。普林斯顿大学的这个DDP+MNIST教程有一些很好的过程图示。

具体来说,本教程展示了如何在Reddit数据集上训练一个GraphSAGE GNN模型。 为此,我们将使用torch.nn.parallel.DistributedDataParallel来在所有可用的GPU上进行扩展训练。 我们将通过从代码中生成多个进程来实现这一点,这些进程都将执行相同的函数。 在每个进程中,我们设置模型实例,并通过使用NeighborLoader来输入数据。 通过将模型包装在torch.nn.parallel.DistributedDataParallel中(如其官方教程所述)来同步梯度,这反过来依赖于torch.distributed的IPC功能。

注意

本教程的完整脚本可以在examples/multi_gpu/distributed_sampling.py找到。

定义一个可生成的运行器

为了创建我们的训练脚本,我们使用了提供的对原生 multiprocessing模块的封装。 在这里,world_size对应于我们将同时使用的GPU数量。 torch.multiprocessing.spawn()将负责生成world_size个进程。 每个进程将加载相同的脚本作为模块,并随后执行run()函数:

from torch_geometric.datasets import Reddit
import torch.multiprocessing as mp

def run(rank: int, world_size: int, dataset: Reddit):
    pass

if __name__ == '__main__':
    dataset = Reddit('./data/Reddit')
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)

请注意,我们在生成任何进程之前初始化数据集。 这样,我们只初始化数据集一次,其中的任何数据将通过torch.multiprocessing自动移动到共享内存中,因此进程不需要创建自己的数据副本。 此外,请注意run()函数如何接受rank作为其第一个参数。 这个参数不是由我们显式提供的。 它对应于由注入的进程ID(从0开始)。 稍后我们将使用它为每个rank选择唯一的GPU。

有了这个,我们可以开始实现我们的可生成运行器函数。 第一步是使用torch.distributed初始化一个进程组。 到目前为止,进程之间并不相互了解,我们使用nccl协议设置了一个硬编码的服务器地址进行集合。 更多详细信息可以在“使用PyTorch编写分布式应用程序”教程中找到:

import os
import torch.distributed as dist
import torch

def run(rank: int, world_size: int, dataset: Reddit):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

接下来,我们将训练索引分割成world_size多个块,每个GPU一个块,并初始化NeighborLoader类,使其仅在其特定的训练集块上操作:

from torch_geometric.loader import NeighborLoader

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    data = dataset[0]

    train_index = data.train_mask.nonzero().view(-1)
    train_index = train_index.split(train_index.size(0) // world_size)[rank]

    train_loader = NeighborLoader(
        data,
        input_nodes=train_index,
        num_neighbors=[25, 10],
        batch_size=1024,
        num_workers=4,
        shuffle=True,
    )

请注意,我们的run()函数是为每个等级调用的,这意味着每个等级都持有一个单独的NeighborLoader实例。

同样地,我们为评估创建了一个NeighborLoader实例。 为了简单起见,我们只在0等级上进行此操作,这样指标的计算就不需要在不同进程之间进行通信。 我们建议查看torchmetrics包以了解指标的分布式计算。

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    if rank == 0:
        val_index = data.val_mask.nonzero().view(-1)
        val_loader = NeighborLoader(
            data,
            input_nodes=val_index,
            num_neighbors=[25, 10],
            batch_size=1024,
            num_workers=4,
            shuffle=False,
        )

现在我们已经定义了数据加载器,我们初始化了GraphSAGE模型,并将其包装在torch.nn.parallel.DistributedDataParallel中。 我们还使用rank作为完整设备标识符的快捷方式,将模型移动到其专用的GPU上。 模型上的包装器管理每个rank之间的通信,并在更新所有rank的模型参数之前同步所有rank的梯度:

from torch.nn.parallel import DistributedDataParallel
from torch_geometric.nn import GraphSAGE

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    torch.manual_seed(12345)
    model = GraphSAGE(
        in_channels=dataset.num_features,
        hidden_channels=256,
        num_layers=2,
        out_channels=dataset.num_classes,
    ).to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])

最后,我们可以设置我们的优化器并定义我们的训练循环,这与通常的单GPU训练循环类似——不同进程之间的梯度和模型权重同步的实际魔法将在DistributedDataParallel背后进行:

import torch.nn.functional as F

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(1, 11):
        model.train()
        for batch in train_loader:
            batch = batch.to(rank)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)[:batch.batch_size]
            loss = F.cross_entropy(out, batch.y[:batch.batch_size])
            loss.backward()
            optimizer.step()

在每个训练周期后,我们评估并报告验证指标。 如前所述,我们仅在单个GPU上进行此操作。 为了同步所有进程并确保模型权重已更新,我们需要调用torch.distributed.barrier()

dist.barrier()

if rank == 0:
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

if rank == 0:
    model.eval()
    count = correct = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(rank)
            out = model(batch.x, batch.edge_index)[:batch.batch_size]
            pred = out.argmax(dim=-1)
            correct += (pred == batch.y[:batch.batch_size]).sum()
            count += batch.batch_size
    print(f'Validation Accuracy: {correct/count:.4f}')

dist.barrier()

训练完成后,我们可以通过以下方式清理进程并销毁进程组:

dist.destroy_process_group()

就是这样。 将所有内容整合在一起,就得到了一个可用的多GPU示例,它遵循与单GPU训练相似的训练流程。 你可以通过查看examples/multi_gpu/distributed_sampling.py自己运行显示的教程。