• Docs >
  • How to do DistributedDataParallel(DDP)
Shortcuts

如何进行分布式数据并行(DDP)

本文档展示了如何在xla中使用torch.nn.parallel.DistributedDataParallel,并进一步描述了它与原生xla数据并行方法的区别。您可以找到一个最小可运行的示例这里

背景 / 动机

客户长期以来一直要求能够使用PyTorch的DistributedDataParallel API与xla。现在我们将其作为一个实验性功能启用。

如何使用分布式数据并行

对于那些从PyTorch的eager模式切换到XLA的用户,这里列出了将您的eager DDP模型转换为XLA模型所需的所有更改。我们假设您已经知道如何在单个设备上使用XLA。

  1. 导入xla特定的分布式包:

import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend
  1. 初始化xla进程组,类似于其他进程组,如nccl和gloo。

dist.init_process_group("xla", rank=rank, world_size=world_size)
  1. 如果需要,请使用xla特定的API来获取rank和world_size。

new_rank = xr.global_ordinal()
world_size = xr.world_size()
  1. gradient_as_bucket_view=True传递给DDP包装器。

ddp_model = DDP(model, gradient_as_bucket_view=True)
  1. 最后使用xla特定的启动器启动您的模型。

torch_xla.launch(demo_fn)

在这里,我们将所有内容整合在一起(这个例子实际上取自DDP教程)。你编写代码的方式与急切体验非常相似。只是在单个设备上添加了xla特定的处理,并对你的脚本进行了上述五项更改。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

# additional imports for xla
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the xla process group
    dist.init_process_group("xla", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 1000000)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1000000, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank):
    # xla specific APIs to get rank, world_size.
    new_rank = xr.global_ordinal()
    assert new_rank == rank
    world_size = xr.world_size()

    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to XLA device
    device = xm.xla_device()
    model = ToyModel().to(device)
    # currently, graident_as_bucket_view is needed to make DDP work for xla
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    # xla specific API to execute the graph
    xm.mark_step()

    cleanup()


def run_demo(demo_fn):
    # xla specific launcher
    torch_xla.launch(demo_fn)

if __name__ == "__main__":
    run_demo(demo_basic)

基准测试

使用假数据的Resnet50

以下结果是使用命令:python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1在TPU VM V3-8环境中使用ToT PyTorch和PyTorch/XLA收集的。统计指标是通过使用此pull request中的脚本生成的。速率的单位是每秒图像数。

Type Mean Median 90th % Std Dev CV
xm.optimizer_step 418.54 419.22 430.40 9.76 0.02
DDP 395.97 395.54 407.13 7.60 0.02

我们原生分布式数据并行方法与DistributedDataParallel包装器之间的性能差异为:1 - 395.97 / 418.54 = 5.39%。考虑到DDP包装器在跟踪DDP运行时引入了额外的开销,这个结果似乎是合理的。

使用假数据的MNIST

以下结果是在TPU VM V3-8环境中使用命令python test/test_train_mp_mnist.py --fake_data收集的,使用了ToT PyTorch和PyTorch/XLA。统计指标是通过使用此pull request中的脚本生成的。速率的单位是每秒处理的图像数。

Type Mean Median 90th % Std Dev CV
xm.optimizer_step 17864.19 20108.96 24351.74 5866.83 0.33
DDP 10701.39 11770.00 14313.78 3102.92 0.29

我们本地的分布式数据并行方法与DistributedDataParallel包装器的性能差异为:1 - 14313.78 / 24351.74 = 41.22%。这里我们比较第90百分位数,因为数据集较小,前几轮受数据加载的影响较大。这种减速是巨大的,但考虑到模型较小,这是合理的。额外的DDP运行时跟踪开销很难分摊。

使用真实数据的MNIST

以下结果是通过在TPU VM V3-8环境中使用ToT PyTorch和PyTorch/XLA执行的命令收集的:python test/test_train_mp_mnist.py --logdir mnist/

learning_curves

我们可以观察到,DDP包装器的收敛速度比原生XLA方法慢,尽管它在最后仍然达到了97.48%的高准确率。(原生方法达到了99%。)

免责声明

此功能仍处于实验阶段,正在积极开发中。使用时请谨慎,并随时向xla github repo提交任何错误。对于那些对原生xla数据并行方法感兴趣的人,这里是教程

以下是一些正在调查中的已知问题:

  • gradient_as_bucket_view=True 需要强制执行。

  • 在使用torch.utils.data.DataLoader时存在一些问题。​​test_train_mp_mnist.py在使用真实数据时会在退出前崩溃。

PyTorch XLA中的完全分片数据并行(FSDP)

PyTorch XLA 中的完全分片数据并行(FSDP)是一种用于在数据并行工作器之间分片模块参数的实用工具。

示例用法:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

也可以单独对各个层进行分片,并让外部包装器处理任何剩余的参数。

备注:

  • XlaFullyShardedDataParallel 类支持在 https://arxiv.org/abs/1910.02054 中提到的 ZeRO-2 优化器(分片梯度和优化器状态)和 ZeRO-3 优化器(分片参数、梯度和优化器状态)。

    • ZeRO-3 优化器应通过嵌套 FSDP 实现,并使用 reshard_after_forward=True。有关示例,请参见 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py

    • 对于无法适应单个TPU内存或主机CPU内存的大型模型,应该将子模块构建与内部FSDP包装交错进行。参见``FSDPViTModel` <https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py>`_ 以获取示例。

  • 提供了一个简单的包装器 checkpoint_module(基于 torch_xla.utils.checkpoint.checkpoint 来自 https://github.com/pytorch/xla/pull/3524),用于在给定的 nn.Module 实例上执行 梯度检查点。请参阅 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py 以获取示例。

  • 自动包装子模块:除了手动嵌套FSDP包装外,还可以指定一个auto_wrap_policy参数来自动用内部FSDP包装子模块。torch_xla.distributed.fsdp.wrap中的size_based_auto_wrap_policyauto_wrap_policy可调用对象的一个示例,该策略会包装参数数量超过100M的层。torch_xla.distributed.fsdp.wrap中的transformer_auto_wrap_policy是用于类似transformer模型架构的auto_wrap_policy可调用对象的示例。

例如,要自动将所有torch.nn.Conv2d子模块用内部FSDP包装,可以使用:

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

此外,还可以指定一个auto_wrapper_callable参数,用于为子模块使用自定义的可调用包装器(默认包装器是XlaFullyShardedDataParallel类本身)。例如,可以使用以下方法将梯度检查点(即激活检查点/重新计算)应用于每个自动包装的子模块。

from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
  • 在优化器步进时,直接调用 optimizer.step,不要调用 xm.optimizer_step。后者会跨等级减少梯度,这对于FSDP(参数已经分片)来说是不必要的。

  • 在训练过程中保存模型和优化器检查点时,每个训练过程都需要保存其自己的(分片的)模型和优化器状态字典(使用master_only=False并在xm.save中为每个等级设置不同的路径)。在恢复时,需要加载相应等级的检查点。

  • 请同时保存 model.get_shard_metadata()model.state_dict() 如下所示,并使用 consolidate_sharded_model_checkpoints 将分片模型检查点拼接成一个完整的模型状态字典。请参阅 test/test_train_mp_mnist_fsdp_with_ckpt.py 以获取示例。 .. code-block:: python3

    ckpt = {

    ‘model’: model.state_dict(), ‘shard_metadata’: model.get_shard_metadata(), ‘optimizer’: optimizer.state_dict(),

    } ckpt_path = f’/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth’ xm.save(ckpt, ckpt_path, master_only=False)

  • 检查点合并脚本也可以从命令行启动,如下所示。 .. code-block:: bash

    # 通过命令行工具合并保存的检查点 python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank--of-.pth”

该类的实现主要受到并大部分遵循了https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.htmlfairscale.nn.FullyShardedDataParallel的结构。与fairscale.nn.FullyShardedDataParallel最大的区别之一是在XLA中没有显式的参数存储,因此在这里我们采用了一种不同的方法来释放ZeRO-3的完整参数。


MNIST 和 ImageNet 上的示例训练脚本

安装

FSDP 在 PyTorch/XLA 1.12 版本及更新的 nightly 版本中可用。请参考 https://github.com/pytorch/xla#-available-images-and-wheels 获取安装指南。

克隆 PyTorch/XLA 仓库

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/

在v3-8 TPU上训练MNIST

它在2个epochs中达到了约98.9的准确率:

python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing

此脚本在最后自动测试检查点合并。您也可以通过以下方式手动合并分片检查点

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"

在v3-8 TPU上使用ResNet-50训练ImageNet

它在100个周期内达到了大约75.9的准确率;下载ImageNet-1k/datasets/imagenet-1k

python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp

你也可以添加--use_gradient_checkpointing(需要与--use_nested_fsdp--auto_wrap_policy一起使用)来对残差块应用梯度检查点。


在TPU pod上运行的训练脚本示例(具有100亿参数)

为了训练无法适应单个TPU的大型模型,应在构建整个模型时应用自动包装或手动包装内部FSDP的子模块,以实现ZeRO-3算法。

请参阅https://github.com/ronghanghu/vit_10b_fsdp_example以了解使用此XLA FSDP PR进行Vision Transformer (ViT)模型分片训练的示例。