纯PyTorch中的多GPU训练
对于许多大规模的真实世界数据集,可能需要在多个GPU上进行扩展训练。本教程介绍了如何在PyG中通过PyTorch使用torch.nn.parallel.DistributedDataParallel设置多GPU训练管道,而无需任何其他第三方库(如PyTorch Lightning)。请注意,此方法基于数据并行。这意味着每个GPU运行模型的相同副本;如果您希望在设备之间扩展模型,您可能需要查看PyTorch FSDP。数据并行允许您通过在GPU之间聚合梯度来增加模型的批量大小,然后在每个模型副本中共享相同的优化器步骤。普林斯顿大学的这个DDP+MNIST教程有一些很好的过程图示。
具体来说,本教程展示了如何在Reddit数据集上训练一个GraphSAGE GNN模型。
为此,我们将使用torch.nn.parallel.DistributedDataParallel来在所有可用的GPU上进行扩展训练。
我们将通过从Python代码中生成多个进程来实现这一点,这些进程都将执行相同的函数。
在每个进程中,我们设置模型实例,并通过使用NeighborLoader来输入数据。
通过将模型包装在torch.nn.parallel.DistributedDataParallel中(如其官方教程所述)来同步梯度,这反过来依赖于torch.distributed的IPC功能。
注意
本教程的完整脚本可以在examples/multi_gpu/distributed_sampling.py找到。
定义一个可生成的运行器
为了创建我们的训练脚本,我们使用了PyTorch提供的对原生Python multiprocessing模块的封装。
在这里,world_size对应于我们将同时使用的GPU数量。
torch.multiprocessing.spawn()将负责生成world_size个进程。
每个进程将加载相同的脚本作为模块,并随后执行run()函数:
from torch_geometric.datasets import Reddit
import torch.multiprocessing as mp
def run(rank: int, world_size: int, dataset: Reddit):
pass
if __name__ == '__main__':
dataset = Reddit('./data/Reddit')
world_size = torch.cuda.device_count()
mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)
请注意,我们在生成任何进程之前初始化数据集。
这样,我们只初始化数据集一次,其中的任何数据将通过torch.multiprocessing自动移动到共享内存中,因此进程不需要创建自己的数据副本。
此外,请注意run()函数如何接受rank作为其第一个参数。
这个参数不是由我们显式提供的。
它对应于由PyTorch注入的进程ID(从0开始)。
稍后我们将使用它为每个rank选择唯一的GPU。
有了这个,我们可以开始实现我们的可生成运行器函数。
第一步是使用torch.distributed初始化一个进程组。
到目前为止,进程之间并不相互了解,我们使用nccl协议设置了一个硬编码的服务器地址进行集合。
更多详细信息可以在“使用PyTorch编写分布式应用程序”教程中找到:
import os
import torch.distributed as dist
import torch
def run(rank: int, world_size: int, dataset: Reddit):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
接下来,我们将训练索引分割成world_size多个块,每个GPU一个块,并初始化NeighborLoader类,使其仅在其特定的训练集块上操作:
from torch_geometric.loader import NeighborLoader
def run(rank: int, world_size: int, dataset: Reddit):
...
data = dataset[0]
train_index = data.train_mask.nonzero().view(-1)
train_index = train_index.split(train_index.size(0) // world_size)[rank]
train_loader = NeighborLoader(
data,
input_nodes=train_index,
num_neighbors=[25, 10],
batch_size=1024,
num_workers=4,
shuffle=True,
)
请注意,我们的run()函数是为每个等级调用的,这意味着每个等级都持有一个单独的NeighborLoader实例。
同样地,我们为评估创建了一个NeighborLoader实例。
为了简单起见,我们只在0等级上进行此操作,这样指标的计算就不需要在不同进程之间进行通信。
我们建议查看torchmetrics包以了解指标的分布式计算。
def run(rank: int, world_size: int, dataset: Reddit):
...
if rank == 0:
val_index = data.val_mask.nonzero().view(-1)
val_loader = NeighborLoader(
data,
input_nodes=val_index,
num_neighbors=[25, 10],
batch_size=1024,
num_workers=4,
shuffle=False,
)
现在我们已经定义了数据加载器,我们初始化了GraphSAGE模型,并将其包装在torch.nn.parallel.DistributedDataParallel中。
我们还使用rank作为完整设备标识符的快捷方式,将模型移动到其专用的GPU上。
模型上的包装器管理每个rank之间的通信,并在更新所有rank的模型参数之前同步所有rank的梯度:
from torch.nn.parallel import DistributedDataParallel
from torch_geometric.nn import GraphSAGE
def run(rank: int, world_size: int, dataset: Reddit):
...
torch.manual_seed(12345)
model = GraphSAGE(
in_channels=dataset.num_features,
hidden_channels=256,
num_layers=2,
out_channels=dataset.num_classes,
).to(rank)
model = DistributedDataParallel(model, device_ids=[rank])
最后,我们可以设置我们的优化器并定义我们的训练循环,这与通常的单GPU训练循环类似——不同进程之间的梯度和模型权重同步的实际魔法将在DistributedDataParallel背后进行:
import torch.nn.functional as F
def run(rank: int, world_size: int, dataset: Reddit):
...
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(1, 11):
model.train()
for batch in train_loader:
batch = batch.to(rank)
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)[:batch.batch_size]
loss = F.cross_entropy(out, batch.y[:batch.batch_size])
loss.backward()
optimizer.step()
在每个训练周期后,我们评估并报告验证指标。
如前所述,我们仅在单个GPU上进行此操作。
为了同步所有进程并确保模型权重已更新,我们需要调用torch.distributed.barrier():
dist.barrier()
if rank == 0:
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
if rank == 0:
model.eval()
count = correct = 0
with torch.no_grad():
for batch in val_loader:
batch = batch.to(rank)
out = model(batch.x, batch.edge_index)[:batch.batch_size]
pred = out.argmax(dim=-1)
correct += (pred == batch.y[:batch.batch_size]).sum()
count += batch.batch_size
print(f'Validation Accuracy: {correct/count:.4f}')
dist.barrier()
训练完成后,我们可以通过以下方式清理进程并销毁进程组:
dist.destroy_process_group()
就是这样。 将所有内容整合在一起,就得到了一个可用的多GPU示例,它遵循与单GPU训练相似的训练流程。 你可以通过查看examples/multi_gpu/distributed_sampling.py自己运行显示的教程。