• Docs >
  • PyTorch/XLA SPMD User Guide
Shortcuts

PyTorch/XLA SPMD 用户指南

在本用户指南中,我们讨论了GSPMD如何集成到PyTorch/XLA中,并提供了一个设计概述,以说明SPMD分片注释API及其构造的工作原理。

什么是PyTorch/XLA SPMD?

GSPMD 是一个用于常见机器学习工作负载的自动并行化系统。XLA 编译器将根据用户提供的分片提示,将单设备程序转换为带有适当集合操作的分区程序。此功能使开发人员能够像在单个大型设备上一样编写 PyTorch 程序,而无需任何自定义的分片计算操作和/或集合通信来进行扩展。

alt_text

*图1. 两种不同执行策略的比较,(a) 非SPMD和(b) SPMD。*

如何使用PyTorch/XLA SPMD?

这是一个使用SPMD的简单示例

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh


# Enable XLA SPMD execution mode.
xr.use_spmd()


# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))


t = torch.randn(8, 4).to(xm.xla_device())


# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)

让我们逐一解释这些概念

SPMD模式

为了使用SPMD,你需要通过xr.use_spmd()来启用它。在SPMD模式下,只有一个逻辑设备。分布式计算和集合操作由mark_sharding处理。请注意,用户不能将SPMD与其他分布式库混合使用。

网格

对于给定的设备集群,物理网格是互连拓扑的表示。

  1. mesh_shape 是一个元组,它将乘以物理设备的总数。

  2. device_ids 几乎总是 np.array(range(num_devices))

  3. 还鼓励用户为每个网格维度命名。在上面的例子中,第一个网格维度是data维度,第二个网格维度是model维度。

您还可以通过以下方式查看更多网格信息

>>> mesh.shape()
OrderedDict([('data', 4), ('model', 1)])

分区规范

partition_spec 的秩与输入张量相同。每个维度描述了相应的输入张量维度如何在设备网格上进行分片。在上面的例子中,张量 t 的第一个维度在 data 维度上进行分片,第二个维度在 model 维度上进行分片。

用户还可以对具有与网格形状不同维度的张量进行分片。

t1 = torch.randn(8, 8, 16).to(device)
t2 = torch.randn(8).to(device)

# First dimension is being replicated.
xs.mark_sharding(t1, mesh, (None, 'data', 'model'))

# First dimension is being sharded at data dimension.
# model dimension is used for replication when omitted.
xs.mark_sharding(t2, mesh, ('data',))

# First dimension is sharded across both mesh axes.
xs.mark_sharding( t2, mesh, (('data', 'model'),))

进一步阅读

  1. Example 使用SPMD来表达数据并行。

  2. Example 使用SPMD来表达FSDP(完全分片数据并行)。

  3. SPMD高级主题

  4. Spmd Distributed Checkpoint

通过SPMD实现的全分片数据并行(FSDP)

通过SPMD或FSDPv2实现的完全分片数据并行是一种工具,它以SPMD重新表达了著名的FSDP算法。是一个实验性功能,旨在为用户提供一个熟悉的界面,以享受SPMD带来的所有好处。设计文档在这里

请在进行之前查看SPMD用户指南。您还可以找到一个最小可运行的示例这里

示例用法:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

# Define the mesh following common SPMD practice
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
# To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'model'))

# Shard the input, and assume x is a 2D tensor.
x = xs.mark_sharding(x, mesh, ('fsdp', None))

# As normal FSDP, but an extra mesh is needed.
model = FSDPv2(my_module, mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

也可以单独对各个层进行分片,并让外部包装器处理任何剩余的参数。以下是一个自动包装每个DecoderLayer的示例。

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Apply FSDP sharding on each DecoderLayer layer.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        decoder_only_model.DecoderLayer
    },
)
model = FSDPv2(
    model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)

分片输出

为了确保XLA编译器正确实现FSDP算法,我们需要对权重和激活进行分片。这意味着需要对前向方法的输出进行分片。由于前向函数的输出可能会有所不同,我们提供了shard_output来在模块输出不属于这些类别的情况下对激活进行分片:

  1. 单个张量

  2. 一个张量元组,其中第0个元素是激活值。

示例用法:

def shard_output(output, mesh):
    xs.mark_sharding(output.logits, mesh, ('fsdp', None, None))

model = FSDPv2(my_module, mesh, shard_output)

梯度检查点

目前,梯度检查点需要在FSDP包装器之前应用于模块。否则,递归循环进入子模块将导致无限循环。我们将在未来的版本中修复此问题。

示例用法:

from torch_xla.distributed.fsdp import checkpoint_module

model = FSDPv2(checkpoint_module(my_module), mesh)

HuggingFace Llama 2 示例

我们有一个HF Llama 2的分支,用于展示潜在的集成这里

PyTorch/XLA SPMD 高级主题

在本文档中,我们将介绍一些关于GSPMD的高级主题。在继续阅读本文档之前,请先阅读SPMD用户指南

PyTorch/XLA SPMD 采用单设备程序,将其分片并并行执行。SPMD 执行需要使用原生的 PyTorch DataLoader,它会将数据从主机同步传输到 XLA 设备。这会在每一步的输入数据传输期间阻塞训练。为了提高原生数据加载性能,我们使 PyTorch/XLA ParallelLoader 直接支持输入分片(src),当传递可选参数 _input_sharding_ 时:

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
        train_loader,  # wraps PyTorch DataLoader
        device,
          # assume 4d input and we want to shard at the batch dimension.
        input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))

如果批次中的每个元素形状不同,也可以为每个元素指定不同的input_sharding

# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor for shape [s1, s2]>}

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
        train_loader,  # wraps PyTorch DataLoader
        device,
          # specify different sharding for each input of the batch.
        input_sharding={
          'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
          'y': xs.ShardingSpec(input_mesh, ('data', None))
        }
)

PyTorch/XLA 通常在张量定义后异步将张量数据从主机传输到设备。这是为了将数据传输与图追踪时间重叠。然而,由于 GSPMD 允许用户在张量定义后修改张量分片,我们需要一种优化来防止张量数据在主机和设备之间不必要的来回传输。我们引入了虚拟设备优化,这是一种在最终确定所有分片决策之前,先将张量数据放置在虚拟设备 SPMD:0 上,然后再上传到物理设备的技术。SPMD 模式中的每个张量数据都放置在虚拟设备 SPMD:0 上。虚拟设备作为 XLA 设备 XLA:0 暴露给用户,实际的分片位于物理设备上,如 TPU:0、TPU:1 等。

混合网格

Mesh很好地抽象了物理设备网格的构建方式。用户可以使用逻辑网格以任何形状和顺序排列设备。然而,可以根据物理拓扑定义一个性能更好的网格,特别是在涉及数据中心网络(DCN)跨切片连接时。HybridMesh创建了一个网格,在这种多切片环境中提供了良好的开箱即用性能。它接受ici_mesh_shape和dcn_mesh_shape,分别表示内部和外部网络的逻辑网格形状。

from torch_xla.distributed.spmd import HybridMesh

# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
# - dcn_mesh_shape: shape of logical mesh for outer connected devices.
ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)

mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])

在TPU Pod上运行SPMD

如果根据设备数量而不是某些硬编码常量来构建网格和分区规范,则无需更改代码即可从单个TPU主机切换到TPU Pod。要在TPU Pod上运行PyTorch/XLA工作负载,请参阅我们的PJRT指南中的Pods部分

XLAShardedTensor

xs.mark_sharding 是一个原地操作,它将分片注释附加到输入张量上,但它也返回一个 XLAShardedTensor Python 对象。

XLAShardedTensor [RFC] 的主要用例是使用分片规范对本地 torch.tensor(在单个设备上)进行注释。注释会立即进行,但张量的实际分片会延迟,因为计算是惰性执行的,除了输入张量会立即分片。一旦张量被注释并包装在 XLAShardedTensor 中,它可以作为 torch.Tensor 传递给现有的 PyTorch 操作和 nn.Module 层。这对于确保相同的 PyTorch 层和张量操作可以与 XLAShardedTensor 堆叠在一起非常重要。这意味着用户不需要为分片计算重写现有的操作和模型代码。即,XLAShardedTensor 将满足以下要求:

  • XLAShardedTensortorch.Tensor 的子类,并且直接与原生 torch 操作和 module.layers 一起工作。我们使用 __torch_dispatch__XLAShardedTensor 发送到 XLA 后端。PyTorch/XLA 检索附加的分片注释以跟踪图并调用 XLA SPMDPartitioner。

  • 在内部,XLAShardedTensor(及其global_tensor输入)由XLATensor支持,该数据结构持有对分片设备数据的引用。

  • 在主机上请求时(例如,打印全局张量的值),延迟执行后的分片张量可能会被收集并物化回主机作为global_tensor。

  • 本地分片的句柄在延迟执行后严格实例化。XLAShardedTensor 暴露了 local_shards 以返回可寻址设备上的本地分片,作为 List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]

目前还在努力将XLAShardedTensor集成到DistributedTensor API中,以支持XLA后端[RFC]。

DTensor 集成

PyTorch 已经在 2.1 版本中发布了 DTensor 的原型。我们正在将 PyTorch/XLA SPMD 集成到 DTensor API 的 RFC 中。我们已经完成了 distribute_tensor 的概念验证集成,它调用 mark_sharding 注释 API 来使用 XLA 对张量及其计算进行分片:

import torch
from torch.distributed import DeviceMesh, Shard, distribute_tensor

# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])

此功能是实验性的,请继续关注即将发布的更多更新、示例和教程。

激活分片用于torch.compile

在2.3版本中,PyTorch/XLA添加了自定义操作dynamo_mark_sharding,该操作可用于在torch.compile区域中执行激活分片。这是我们持续努力的一部分,旨在使torch.compile + GSPMD成为使用PyTorch/XLA进行模型推理的推荐方式。使用此自定义操作的示例如下:

# Activation output sharding
device_ids = [i for i in range(self.num_devices)] # List[int]
mesh_shape = [self.num_devices//2, 1, 2] # List[int]
axis_names = "('data', 'model')" # string version of axis_names
partition_spec = "('data', 'model')" # string version of partition spec
torch.ops.xla.dynamo_mark_sharding(output, device_ids, mesh_shape, axis_names, partition_spec)

SPMD调试工具

我们为在TPU/GPU/CPU上使用单主机/多主机的PyTorch/XLA SPMD用户提供了一个分片放置可视化调试工具:您可以使用visualize_tensor_sharding来可视化分片张量,或者使用visualize_sharding来可视化共享字符串。以下是两个在TPU单主机(v4-8)上使用visualize_tensor_shardingvisualize_sharding的代码示例:

  • 代码片段使用了 visualize_tensor_sharding 和可视化结果:

import rich

# Here, mesh is a 2x2 mesh with axes 'x' and 'y'
t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))

# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)
visualize_tensor_sharding example on TPU v4-8(single-host)
  • 代码片段使用了 visualize_sharding 和可视化结果:

from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
visualize_sharding example on TPU v4-8(single-host)

你可以在TPU/GPU/CPU单主机上使用这些示例,并修改它以在多主机上运行。你也可以修改它以分片风格的tiledpartial_replicationreplicated

自动分片

我们正在引入一个名为auto-sharding的新PyTorch/XLA SPMD功能,RFC。这是r2.3nightly中的一个实验性功能,支持XLA:TPU和单个TPUVM主机。

PyTorch/XLA 自动分片可以通过以下方式之一启用:

  • 设置环境变量 XLA_AUTO_SPMD=1

  • 在代码的开头调用SPMD API:

import torch_xla.runtime as xr
xr.use_spmd(auto=True)
  • 调用 pytorch.distributed._tensor.distribute_module 使用 auto-policyxla:

import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, model should be loaded to xla device via distribute_module.
model = MyModule()  # nn.module
sharded_model = distribute_module(model, device_mesh, auto_policy)

可选地,可以设置以下选项/环境变量来控制基于XLA的自动分片过程的行为:

  • XLA_AUTO_USE_GROUP_SHARDING: 参数的组重分片。默认设置。

  • XLA_AUTO_SPMD_MESH: 用于自动分片的逻辑网格形状。例如, XLA_AUTO_SPMD_MESH=2,2 对应于一个2x2的网格,包含4个全局设备。如果未设置, 将使用默认的设备网格形状 num_devices,1

分布式检查点

PyTorch/XLA SPMD 通过专用的 Planner 实例与 torch.distributed.checkpoint 库兼容。用户能够通过这个通用接口同步保存和加载检查点。

SPMDSavePlanner 和 SPMDLoadPlanner (src) 类使得 saveload 函数能够直接在 XLAShardedTensor 的分片上操作,从而在 SPMD 训练中实现分布式检查点的所有优势。

以下是同步分布式检查点API的演示:

import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

# Saving a state_dict
state_dict = {
    "model": model.state_dict(),
    "optim": optim.state_dict(),
}

dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=xc.SPMDSavePlanner(),
)
...

# Loading the model's state_dict from the checkpoint. The model should
# already be on the XLA device and have the desired sharding applied.
state_dict = {
    "model": model.state_dict(),
}

dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])

实验性的CheckpointManager接口在torch.distributed.checkpoint函数之上提供了一个更高级的API,以实现一些关键功能:

  • 托管检查点:由CheckpointManager创建的每个检查点都通过其创建时的步骤进行标识。所有跟踪的步骤都可以通过CheckpointManager.all_steps方法访问,并且可以使用CheckpointManager.restore恢复任何跟踪的步骤。

  • 异步检查点:通过CheckpointManager.save_async API创建的检查点会异步写入持久存储,以便在检查点期间不阻塞训练。在将检查点分派到后台线程之前,输入的分片状态字典首先会被移动到CPU。

  • 抢占时的自动检查点:在Cloud TPU上,可以检测到抢占并在进程终止之前进行检查点。要使用此功能,请确保您的TPU是通过启用了自动检查点的QueuedResource进行配置的,并确保在构建CheckpointManager时设置了chkpt_on_preemption参数(此选项默认启用)。

  • FSSpec 支持: CheckpointManager 使用 fsspec 存储后端,以支持直接在任何与 fsspec 兼容的文件系统(包括 GCS)上进行检查点操作。

CheckpointManager 的示例如下:

from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer

# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)

# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
    # Choose the highest step
    best_step = max(tracked_steps)
    # Before restoring the checkpoint, the optimizer state must be primed
    # to allow state to be loaded into it.
    prime_optimizer(optim)
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    chkpt_mgr.restore(best_step, state_dict)
    model.load_state_dict(state_dict['model'])
    optim.load_state_dict(state_dict['optim'])

# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
    ...
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    if chkpt_mgr.save_async(step, state_dict):
        print(f'Checkpoint taken at step {step}')

在分布式检查点中,state_dicts 是就地加载的,并且只加载检查点的所需分片。由于优化器状态是延迟创建的,状态在第一次 optimizer.step 调用之前不存在,尝试加载未初始化的优化器将会失败。

为此提供了实用方法 prime_optimizer:它通过将所有梯度设置为零并调用 optimizer.step 来运行一个虚假的训练步骤。这是一个破坏性方法,会影响模型参数和优化器状态,因此应仅在恢复之前调用。

要使用torch.distributed API(如分布式检查点),需要一个进程组。在SPMD模式下,不支持xla后端,因为编译器负责所有集合操作。

相反,必须使用诸如gloo之类的CPU进程组。在TPU上,仍然支持xla://初始化方法来发现主IP、全局世界大小和主机排名。以下是一个初始化示例:

import torch.distributed as dist
# Import to register the `xla://` init_method
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr

xr.use_spmd()

# The `xla://` init_method will automatically discover master worker IP, rank,
# and global world size without requiring environment configuration on TPUs.
dist.init_process_group('gloo', init_method='xla://')