torch.distributed.tensor.parallel.style 的源代码
# 版权所有 (c) Meta Platforms, Inc. 及其附属公司
from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple
from functools import partial
import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor, distribute_module
__all__ = [
"ParallelStyle",
"RowwiseParallel",
"SequenceParallel",
"ColwiseParallel",
"PrepareModuleInput",
"PrepareModuleOutput",
]
class ParallelStyle(ABC):
"""
并行样式契约定义了模块或子模块应如何并行化。
它仅定义了 ``parallelize_module`` 使用的 ``apply`` 方法,这为不同类型的样式实现提供了最大的灵活性。
"""
@abstractmethod
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
...
[docs]class ColwiseParallel(ParallelStyle):
"""
以列方式对兼容的 nn.Module 进行分区。目前支持 nn.Linear 和 nn.Embedding。
用户可以将其与 RowwiseParallel 组合以实现更复杂模块的分片。
(例如 MLP, Attention)
关键字参数:
input_layouts (Placement, 可选):
用于注释 nn.Module 输入张量的 DTensor 布局,这将用于将输入张量转换为 DTensor。如果未指定,我们假设输入张量是复制的。
output_layouts (Placement, 可选):
用于确保 nn.Module 输出的 DTensor 布局,这将用于确保 nn.Module 的输出具有用户期望的布局。如果未指定,输出张量将在最后一个维度上分片。
use_local_output (bool, 可选):
是否使用本地 :class:`torch.Tensor` 而不是 :class:`DTensor` 作为模块输出,默认值: True。
返回:
表示 nn.Module 列分片的 :class:`ParallelStyle` 对象。
示例::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...) # m 是一个包含 "w1" nn.Linear 子模块的 nn.Module
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # 默认情况下,"w1" Linear 的输入将转换为复制的 DTensor
>>> # 并且 "w1" 的输出将返回在最后一个维度上分片的 :class:`torch.Tensor`。
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
>>> ...
.. 注意:: 默认情况下,如果未指定 ``output_layouts``,``ColwiseParallel`` 输出将在最后一个维度上分片。如果存在需要特定张量形状的操作符(例如在配对的 ``RowwiseParallel`` 之前),请记住,如果输出被分片,操作符可能需要调整为分片大小。
"""
def __init__(
self,
*,
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True
):
super().__init__()
self.input_layouts = (input_layouts or Replicate(), )
self.output_layouts = (output_layouts or Shard(-1), )
# 列线性运行时分片(期望的分片):
# 1. 需要复制输入
# 2. 在最后一个维度上分片输出
self.desired_input_layouts = (Replicate(), )
self.use_local_output = use_local_output
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
# TODO: 找出对实例方法的 dynamo 支持并将其切换为实例方法
# 使用 input_layouts 注释模块输入布局/分片
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
# 将输入布局转换为 ColwiseParallel 的期望布局
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
def _partition_linear_fn(self, name, module, device_mesh):
# 列分片权重/偏置为 Shard(0),权重为 Shard(0)
# 意味着列线性输入 * 权重^T + 偏置,其中
# 权重将变为 Shard(1)
for name, param in module.named_parameters():
dist_param = nn.Parameter(
distribute_tensor(param, device_mesh, [Shard(0)])
)
module.register_parameter(name, dist_param)
def _partition_embedding_fn(self, name, module, device_mesh):
# 列分片 embedding.weight 直接为 Shard(1)
for name, param in module.named_parameters():
dist