torch.distributed.tensor.parallel.api 的源代码
# 版权所有 (c) Meta Platforms, Inc. 及其附属公司
from typing import Dict, Union
import torch
import torch.distributed._tensor.random as random
import torch.nn as nn
from torch.distributed._tensor import (
DeviceMesh,
)
from torch.distributed._tensor.random import (
is_rng_supported_mesh,
TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import (
ParallelStyle,
)
__all__ = [
"parallelize_module",
]
[docs]def parallelize_module( # type: ignore[return]
module: nn.Module,
device_mesh: DeviceMesh,
parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
) -> nn.Module:
"""
在 PyTorch 中通过并行化模块或子模块来应用张量并行。
我们根据用户指定的计划并行化模块或子模块。parallelize_plan 包含
:class:`ParallelStyle`,指示用户希望如何并行化模块或子模块。
用户还可以为每个模块的完全限定名称 (FQN) 指定不同的并行风格。
请注意,``parallelize_module`` 仅接受 1 维的 :class:`DeviceMesh`,如果您有一个 2 维或 N 维的 :class:`DeviceMesh`,
请先将其切片为 1 维的子 DeviceMesh,然后传递给此 API(即 ``device_mesh["tp"]``)
参数:
module (:class:`nn.Module`):
要并行化的模块。
device_mesh (:class:`DeviceMesh`):
描述 DTensor 的设备网格拓扑的对象。
parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
用于并行化模块的计划。它可以是
:class:`ParallelStyle` 对象,包含如何为张量并行准备输入/输出,或者它可以是
包含模块 FQN 及其对应 :class:`ParallelStyle` 对象的字典。
返回:
一个并行化的 :class:`nn.Module` 对象。
示例::
>>> # xdoctest: +SKIP("distributed")
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> # 定义模块。
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
>>>
.. 注意:: 对于像 Attention、MLP 层这样的复杂模块架构,我们建议将不同的 ParallelStyles 组合在一起(例如 ``ColwiseParallel`` 和 ``RowwiseParallel``)并传递
作为 parallelize_plan,以实现所需的拆分计算。
"""
torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
_validate_tp_mesh_dim(device_mesh)
# 如果还没有 TP RNG 状态跟踪器,则实例化一个
if is_rng_supported_mesh(device_mesh) and not isinstance(
random._rng_tracker, TensorParallelRNGTracker
):
random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
# TODO: 我们应该允许用户从配置中传递默认种子
random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
# 默认情况下,我们在非张量并行区域中执行随机操作。如果用户希望
# 在张量并行区域中执行,他们可以手动将此字段设置为 True
# 在并行化模型之后。
random._rng_tracker.distribute_region_enabled = False
if isinstance(parallelize_plan, ParallelStyle):
return parallelize_plan._apply(module, device_mesh)
elif isinstance(parallelize_plan, dict):
for module_path, parallelize_style in parallelize_plan.items():
sub_module = module.get_submodule(module_path)
parent_module = module
if "." in module_path:
parent_module_path = ".".join(module_path.split(".")[:-1])
parent_module = module.get_submodule(parent_module_path)
module_path = module_path.split(".")[-1]
parent_module.register_module( # type: ignore[call-arg] # pyre-ignore[20]
module_path,
parallelize_module( # type: ignore[arg-type]
sub_module, device_mesh, parallelize_style # type: ignore[arg-type] # pyre-ignore[6]
),
)
return module
else:
raise RuntimeError( # pyre-ignore[7]
"Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
f" parallelize_plan, {type(parallelize_plan)} found!"
)