Shortcuts

torch.distributed.device_mesh 的源代码

# 版权所有 (c) Meta Platforms, Inc. 及其附属公司
import logging
import math
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch

from torch.distributed import is_available

from ..utils._typing_utils import not_none

__all__ = ["init_device_mesh", "DeviceMesh"]


if not is_available():
    import sys

    # torch.distributed is not available in this build, so publish placeholder
    # objects under the public names of this module. Without them the doc tests
    # (```./.ci/pytorch/docs-test.sh```) fail, because they import
    # ``torch.distributed.device_mesh`` / ``torch.distributed.init_device_mesh``
    # and would not find them.

    class _DeviceMeshStub:
        """Placeholder exposed as ``DeviceMesh`` when distributed is unavailable."""

    def _init_device_mesh_stub():
        """Placeholder exposed as ``init_device_mesh``; does nothing."""

    # Attach the placeholders to the module object registered under this
    # module's import name (DeviceMesh first, then init_device_mesh).
    setattr(
        sys.modules["torch.distributed.device_mesh"],
        "DeviceMesh",
        _DeviceMeshStub,
    )
    setattr(
        sys.modules["torch.distributed.device_mesh"],
        "init_device_mesh",
        _init_device_mesh_stub,
    )


else:
    from torch.distributed.distributed_c10d import (
        _find_pg_by_ranks_and_tag,
        _get_default_group,
        _get_group_tag,
        get_rank,
        get_world_size,
        init_process_group,
        is_initialized,
        new_group,
        ProcessGroup,
    )

    # Module-level logger, named after this module.
    logger = logging.getLogger(__name__)

    # Import numpy typing helpers only during static type checking.
    # NOTE(review): TYPE_CHECKING is False at runtime, so this block (including
    # the warning below) is only seen by static type checkers and never
    # executes when the module is imported.
    if TYPE_CHECKING:
        try:
            from numpy.typing import ArrayLike
        except ImportError:
            # Message meaning: DeviceMesh needs numpy >= 1.21 installed for
            # type checking.
            logger.warning(
                "DeviceMesh 需要 numpy >= 1.21 安装以进行类型检查"
            )

    class _MeshEnv:
        def __init__(self) -> None:
            self.mesh_stack: List[DeviceMesh] = []
            self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}

        def get_current_mesh(self) -> "DeviceMesh":
            if len(self.mesh_stack) == 0:
                raise RuntimeError("当前没有活动的设备网格!")
            return self.mesh_stack[-1]

        def create_child_mesh(
            self,
            device_mesh: "DeviceMesh",
            mesh_dim_names: Tuple[str, ...],
        ) -> "DeviceMesh":
            """Return the sub-mesh of ``device_mesh`` along ``mesh_dim_names``
            that contains the current rank.

            The selected dimensions are moved to the end of the parent mesh
            (in the requested order) and all other dimensions are flattened,
            so each row of the result is one candidate sub-mesh. The row that
            contains ``device_mesh.get_rank()`` becomes the child mesh. The
            child's ``_dim_group_infos`` reuses the parent's entries for the
            selected dimensions. The child is also recorded in
            ``child_to_parent_mapping`` with ``device_mesh`` as its parent.

            Args:
                device_mesh: Parent mesh to slice. ``not_none`` is applied to
                    its ``mesh_dim_names``.
                mesh_dim_names: Names of the parent dimensions to keep, in the
                    order they should appear in the child mesh.

            Returns:
                The child ``DeviceMesh`` containing the current rank.

            Raises:
                ValueError: If a name is not a dimension of ``device_mesh``, or
                    if a name is repeated.
                RuntimeError: If the current rank is not found in any slice.
            """
            # Map each requested name to its index in the parent mesh; an
            # unknown name raises ValueError from ``index``.
            parent_dim_names = not_none(device_mesh.mesh_dim_names)
            mesh_dims = [parent_dim_names.index(name) for name in mesh_dim_names]
            if len(set(mesh_dims)) != len(mesh_dims):
                raise ValueError(
                    f"Duplicate mesh dimension names in {mesh_dim_names}."
                )

            cur_rank = device_mesh.get_rank()
            mesh = device_mesh.mesh

            # Dimensions that are not selected keep their ascending order.
            selected = set(mesh_dims)
            other_dims = [dim for dim in range(mesh.ndim) if dim not in selected]
            mesh_sizes = [mesh.size(dim) for dim in mesh_dims]

            # Selected dims go last, everything else is flattened into the
            # leading axis: each row is one candidate sub-mesh.
            pg_ranks_by_dim = mesh.permute(*other_dims, *mesh_dims).reshape(
                -1, *mesh_sizes
            )

            res_sub_mesh: Optional["DeviceMesh"] = None
            for mesh_nd in pg_ranks_by_dim:
                if cur_rank in mesh_nd:
                    res_sub_mesh = DeviceMesh(
                        device_mesh.device_type,
                        mesh_nd,
                        mesh_dim_names=mesh_dim_names,
                    )
                    # Assumes ranks are unique within a mesh, so at most one
                    # row can match; stop scanning once it is found.
                    break

            if res_sub_mesh is None:
                # Without this check a missing match would surface as an
                # opaque UnboundLocalError.
                raise RuntimeError(
                    f"Rank {cur_rank} is not part of any sub-mesh along "
                    f"dimensions {mesh_dim_names}."
                )

            # Reuse the parent's group info entries for the selected dimensions.
            res_sub_mesh._dim_group_infos = [
                device_mesh._dim_group_infos[mesh_dim] for mesh_dim in mesh_dims
            ]

            # Record the current DeviceMesh as the parent of the child mesh.
            self.child_to_parent_mapping[res_sub_mesh] = device_mesh
            return res_sub_mesh

        def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
            return self.child_to_parent_mapping<span class="