Shortcuts

torch.distributed.tensor.parallel.style 的源代码

# 版权所有 (c) Meta Platforms, Inc. 及其附属公司
from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple
from functools import partial

import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor, distribute_module


__all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "SequenceParallel",
    "ColwiseParallel",
    "PrepareModuleInput",
    "PrepareModuleOutput",
]


class ParallelStyle(ABC):
    """
    并行样式契约定义了模块或子模块应如何并行化。

    它仅定义了 ``parallelize_module`` 使用的 ``apply`` 方法,这为不同类型的样式实现提供了最大的灵活性。
    """

    @abstractmethod
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        ...


[docs]class ColwiseParallel(ParallelStyle): """ 以列方式对兼容的 nn.Module 进行分区。目前支持 nn.Linear 和 nn.Embedding。 用户可以将其与 RowwiseParallel 组合以实现更复杂模块的分片。 (例如 MLP, Attention) 关键字参数: input_layouts (Placement, 可选): 用于注释 nn.Module 输入张量的 DTensor 布局,这将用于将输入张量转换为 DTensor。如果未指定,我们假设输入张量是复制的。 output_layouts (Placement, 可选): 用于确保 nn.Module 输出的 DTensor 布局,这将用于确保 nn.Module 的输出具有用户期望的布局。如果未指定,输出张量将在最后一个维度上分片。 use_local_output (bool, 可选): 是否使用本地 :class:`torch.Tensor` 而不是 :class:`DTensor` 作为模块输出,默认值: True。 返回: 表示 nn.Module 列分片的 :class:`ParallelStyle` 对象。 示例:: >>> # xdoctest: +SKIP(failing) >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m 是一个包含 "w1" nn.Linear 子模块的 nn.Module >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # 默认情况下,"w1" Linear 的输入将转换为复制的 DTensor >>> # 并且 "w1" 的输出将返回在最后一个维度上分片的 :class:`torch.Tensor`。 >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) >>> ... .. 注意:: 默认情况下,如果未指定 ``output_layouts``,``ColwiseParallel`` 输出将在最后一个维度上分片。如果存在需要特定张量形状的操作符(例如在配对的 ``RowwiseParallel`` 之前),请记住,如果输出被分片,操作符可能需要调整为分片大小。 """ def __init__( self, *, input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, use_local_output: bool = True ): super().__init__() self.input_layouts = (input_layouts or Replicate(), ) self.output_layouts = (output_layouts or Shard(-1), ) # 列线性运行时分片(期望的分片): # 1. 需要复制输入 # 2. 在最后一个维度上分片输出 self.desired_input_layouts = (Replicate(), ) self.use_local_output = use_local_output @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): # TODO: 找出对实例方法的 dynamo 支持并将其切换为实例方法 # 使用 input_layouts 注释模块输入布局/分片 input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) # 将输入布局转换为 ColwiseParallel 的期望布局 if input_layouts != desired_input_layouts: input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) return input_tensor def _partition_linear_fn(self, name, module, device_mesh): # 列分片权重/偏置为 Shard(0),权重为 Shard(0) # 意味着列线性输入 * 权重^T + 偏置,其中 # 权重将变为 Shard(1) for name, param in module.named_parameters(): dist_param = nn.Parameter( distribute_tensor(param, device_mesh, [Shard(0)]) ) module.register_parameter(name, dist_param) def _partition_embedding_fn(self, name, module, device_mesh): # 列分片 embedding.weight 直接为 Shard(1) for name, param in module.named_parameters(): dist
优云智算