如何进行分布式数据并行(DDP)¶
本文档展示了如何在xla中使用torch.nn.parallel.DistributedDataParallel,并进一步描述了它与原生xla数据并行方法的区别。您可以找到一个最小可运行的示例这里。
背景 / 动机¶
客户长期以来一直要求能够使用PyTorch的DistributedDataParallel API与xla。现在我们将其作为一个实验性功能启用。
如何使用分布式数据并行¶
对于那些从PyTorch的eager模式切换到XLA的用户,这里列出了将您的eager DDP模型转换为XLA模型所需的所有更改。我们假设您已经知道如何在单个设备上使用XLA。
导入xla特定的分布式包:
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend
初始化xla进程组,类似于其他进程组,如nccl和gloo。
dist.init_process_group("xla", rank=rank, world_size=world_size)
如果需要,请使用xla特定的API来获取rank和world_size。
new_rank = xr.global_ordinal()
world_size = xr.world_size()
将
gradient_as_bucket_view=True传递给DDP包装器。
ddp_model = DDP(model, gradient_as_bucket_view=True)
最后使用xla特定的启动器启动您的模型。
torch_xla.launch(demo_fn)
在这里,我们将所有内容整合在一起(这个例子实际上取自DDP教程)。你编写代码的方式与急切体验非常相似。只是在单个设备上添加了xla特定的处理,并对你的脚本进行了上述五项更改。
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
# additional imports for xla
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the xla process group
dist.init_process_group("xla", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 1000000)
self.relu = nn.ReLU()
self.net2 = nn.Linear(1000000, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank):
# xla specific APIs to get rank, world_size.
new_rank = xr.global_ordinal()
assert new_rank == rank
world_size = xr.world_size()
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
# create model and move it to XLA device
device = xm.xla_device()
model = ToyModel().to(device)
# currently, graident_as_bucket_view is needed to make DDP work for xla
ddp_model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10).to(device))
labels = torch.randn(20, 5).to(device)
loss_fn(outputs, labels).backward()
optimizer.step()
# xla specific API to execute the graph
xm.mark_step()
cleanup()
def run_demo(demo_fn):
# xla specific launcher
torch_xla.launch(demo_fn)
if __name__ == "__main__":
run_demo(demo_basic)
基准测试¶
使用假数据的Resnet50¶
以下结果是使用命令:python
test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1在TPU VM V3-8环境中使用ToT PyTorch和PyTorch/XLA收集的。统计指标是通过使用此pull
request中的脚本生成的。速率的单位是每秒图像数。
| Type | Mean | Median | 90th % | Std Dev | CV |
| xm.optimizer_step | 418.54 | 419.22 | 430.40 | 9.76 | 0.02 |
| DDP | 395.97 | 395.54 | 407.13 | 7.60 | 0.02 |
我们原生分布式数据并行方法与DistributedDataParallel包装器之间的性能差异为:1 - 395.97 / 418.54 = 5.39%。考虑到DDP包装器在跟踪DDP运行时引入了额外的开销,这个结果似乎是合理的。
使用假数据的MNIST¶
以下结果是在TPU VM V3-8环境中使用命令python
test/test_train_mp_mnist.py --fake_data收集的,使用了ToT PyTorch和PyTorch/XLA。统计指标是通过使用此pull request中的脚本生成的。速率的单位是每秒处理的图像数。
| Type | Mean | Median | 90th % | Std Dev | CV |
| xm.optimizer_step | 17864.19 | 20108.96 | 24351.74 | 5866.83 | 0.33 |
| DDP | 10701.39 | 11770.00 | 14313.78 | 3102.92 | 0.29 |
我们本地的分布式数据并行方法与DistributedDataParallel包装器的性能差异为:1 - 14313.78 / 24351.74 = 41.22%。这里我们比较第90百分位数,因为数据集较小,前几轮受数据加载的影响较大。这种减速是巨大的,但考虑到模型较小,这是合理的。额外的DDP运行时跟踪开销很难分摊。
使用真实数据的MNIST¶
以下结果是通过在TPU VM V3-8环境中使用ToT PyTorch和PyTorch/XLA执行的命令收集的:python
test/test_train_mp_mnist.py --logdir mnist/。
我们可以观察到,DDP包装器的收敛速度比原生XLA方法慢,尽管它在最后仍然达到了97.48%的高准确率。(原生方法达到了99%。)
免责声明¶
此功能仍处于实验阶段,正在积极开发中。使用时请谨慎,并随时向xla github repo提交任何错误。对于那些对原生xla数据并行方法感兴趣的人,这里是教程。
以下是一些正在调查中的已知问题:
gradient_as_bucket_view=True需要强制执行。在使用
torch.utils.data.DataLoader时存在一些问题。test_train_mp_mnist.py在使用真实数据时会在退出前崩溃。
PyTorch XLA中的完全分片数据并行(FSDP)¶
PyTorch XLA 中的完全分片数据并行(FSDP)是一种用于在数据并行工作器之间分片模块参数的实用工具。
示例用法:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
也可以单独对各个层进行分片,并让外部包装器处理任何剩余的参数。
备注:
XlaFullyShardedDataParallel类支持在 https://arxiv.org/abs/1910.02054 中提到的 ZeRO-2 优化器(分片梯度和优化器状态)和 ZeRO-3 优化器(分片参数、梯度和优化器状态)。ZeRO-3 优化器应通过嵌套 FSDP 实现,并使用
reshard_after_forward=True。有关示例,请参见test/test_train_mp_mnist_fsdp_with_ckpt.py和test/test_train_mp_imagenet_fsdp.py。对于无法适应单个TPU内存或主机CPU内存的大型模型,应该将子模块构建与内部FSDP包装交错进行。参见``FSDPViTModel` <https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py>`_ 以获取示例。
提供了一个简单的包装器
checkpoint_module(基于torch_xla.utils.checkpoint.checkpoint来自 https://github.com/pytorch/xla/pull/3524),用于在给定的nn.Module实例上执行 梯度检查点。请参阅test/test_train_mp_mnist_fsdp_with_ckpt.py和test/test_train_mp_imagenet_fsdp.py以获取示例。自动包装子模块:除了手动嵌套FSDP包装外,还可以指定一个
auto_wrap_policy参数来自动用内部FSDP包装子模块。torch_xla.distributed.fsdp.wrap中的size_based_auto_wrap_policy是auto_wrap_policy可调用对象的一个示例,该策略会包装参数数量超过100M的层。torch_xla.distributed.fsdp.wrap中的transformer_auto_wrap_policy是用于类似transformer模型架构的auto_wrap_policy可调用对象的示例。
例如,要自动将所有torch.nn.Conv2d子模块用内部FSDP包装,可以使用:
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
此外,还可以指定一个auto_wrapper_callable参数,用于为子模块使用自定义的可调用包装器(默认包装器是XlaFullyShardedDataParallel类本身)。例如,可以使用以下方法将梯度检查点(即激活检查点/重新计算)应用于每个自动包装的子模块。
from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
checkpoint_module(m), *args, **kwargs)
在优化器步进时,直接调用
optimizer.step,不要调用xm.optimizer_step。后者会跨等级减少梯度,这对于FSDP(参数已经分片)来说是不必要的。在训练过程中保存模型和优化器检查点时,每个训练过程都需要保存其自己的(分片的)模型和优化器状态字典(使用
master_only=False并在xm.save中为每个等级设置不同的路径)。在恢复时,需要加载相应等级的检查点。请同时保存
model.get_shard_metadata()和model.state_dict()如下所示,并使用consolidate_sharded_model_checkpoints将分片模型检查点拼接成一个完整的模型状态字典。请参阅test/test_train_mp_mnist_fsdp_with_ckpt.py以获取示例。 .. code-block:: python3- ckpt = {
‘model’: model.state_dict(), ‘shard_metadata’: model.get_shard_metadata(), ‘optimizer’: optimizer.state_dict(),
} ckpt_path = f’/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth’ xm.save(ckpt, ckpt_path, master_only=False)
检查点合并脚本也可以从命令行启动,如下所示。 .. code-block:: bash
# 通过命令行工具合并保存的检查点 python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank--of-.pth”
该类的实现主要受到并大部分遵循了https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html中fairscale.nn.FullyShardedDataParallel的结构。与fairscale.nn.FullyShardedDataParallel最大的区别之一是在XLA中没有显式的参数存储,因此在这里我们采用了一种不同的方法来释放ZeRO-3的完整参数。
MNIST 和 ImageNet 上的示例训练脚本¶
最小示例:``examples/fsdp/train_resnet_fsdp_auto_wrap.py` <https://github.com/pytorch/xla/blob/master/examples/fsdp/train_resnet_fsdp_auto_wrap.py>`_
MNIST: ``test/test_train_mp_mnist_fsdp_with_ckpt.py` <https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_fsdp_with_ckpt.py>`_ (它还测试了检查点合并)
ImageNet: ``test/test_train_mp_imagenet_fsdp.py` <https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_fsdp.py>`_
安装¶
FSDP 在 PyTorch/XLA 1.12 版本及更新的 nightly 版本中可用。请参考 https://github.com/pytorch/xla#-available-images-and-wheels 获取安装指南。
克隆 PyTorch/XLA 仓库¶
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在v3-8 TPU上训练MNIST¶
它在2个epochs中达到了约98.9的准确率:
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp --use_gradient_checkpointing
此脚本在最后自动测试检查点合并。您也可以通过以下方式手动合并分片检查点
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在v3-8 TPU上使用ResNet-50训练ImageNet¶
它在100个周期内达到了大约75.9的准确率;下载ImageNet-1k到/datasets/imagenet-1k:
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
--use_nested_fsdp
你也可以添加--use_gradient_checkpointing(需要与--use_nested_fsdp或--auto_wrap_policy一起使用)来对残差块应用梯度检查点。
在TPU pod上运行的训练脚本示例(具有100亿参数)¶
为了训练无法适应单个TPU的大型模型,应在构建整个模型时应用自动包装或手动包装内部FSDP的子模块,以实现ZeRO-3算法。
请参阅https://github.com/ronghanghu/vit_10b_fsdp_example以了解使用此XLA FSDP PR进行Vision Transformer (ViT)模型分片训练的示例。