torch.distributed.device_mesh 的源代码
# 版权所有 (c) Meta Platforms, Inc. 及其附属公司
import logging
import math
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
from ..utils._typing_utils import not_none
__all__ = ["init_device_mesh", "DeviceMesh"]
if not is_available():
    # Function-free fallback path: taken only when ``is_available()`` (imported
    # from ``torch.distributed``) returns a falsy value.
    import sys

    # We need to create the stubs when distributed is not available.
    # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
    # since it would try to import ``torch.distributed.device_mesh`` or
    # ``torch.distributed.init_device_mesh`` but cannot find them.

    # Empty placeholder published under the public name ``DeviceMesh``.
    class _DeviceMeshStub:
        pass

    # No-op placeholder published under the public name ``init_device_mesh``.
    # NOTE(review): it accepts no arguments, so it only satisfies imports;
    # calling it the way the real function is called would fail - confirm
    # that this is intended.
    def _init_device_mesh_stub():
        pass

    # Attach both placeholders as attributes of the module object registered
    # under "torch.distributed.device_mesh" (presumably this very module -
    # TODO confirm).
    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub  # type: ignore[attr-defined]
    sys.modules[
        "torch.distributed.device_mesh"
    ].init_device_mesh = _init_device_mesh_stub  # type: ignore[attr-defined]
else:
from torch.distributed.distributed_c10d import (
_find_pg_by_ranks_and_tag,
_get_default_group,
_get_group_tag,
get_rank,
get_world_size,
init_process_group,
is_initialized,
new_group,
ProcessGroup,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

# only import numpy typing when type checking
if TYPE_CHECKING:
    try:
        from numpy.typing import ArrayLike
    except ImportError:
        # A failed import is reported as a warning instead of being re-raised.
        logger.warning(
            "DeviceMesh 需要 numpy >= 1.21 安装以进行类型检查"
        )
class _MeshEnv:
def __init__(self) -> None:
self.mesh_stack: List[DeviceMesh] = []
self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
def get_current_mesh(self) -> "DeviceMesh":
    """Return the mesh at the top (last position) of ``self.mesh_stack``.

    Raises:
        RuntimeError: if the stack is empty.
    """
    if self.mesh_stack:
        return self.mesh_stack[-1]
    raise RuntimeError("当前没有活动的设备网格!")
def create_child_mesh(
    self,
    device_mesh: "DeviceMesh",
    mesh_dim_names: Tuple[str, ...],
) -> "DeviceMesh":
    """Build the child mesh of ``device_mesh`` spanning the named dims.

    The parent's rank tensor is rearranged so that the selected dims come
    last, every remaining dim is flattened into one leading axis, and the
    slice along that axis that contains the current rank becomes the child.

    Args:
        device_mesh: The parent mesh to slice. Its ``mesh_dim_names`` are
            used to translate names into dim indices.
        mesh_dim_names: Names of the parent dims to keep, in the order they
            should appear in the child. Any number of names is accepted.

    Returns:
        A new ``DeviceMesh`` built from the ranks of the selected slice. It
        reuses the parent's ``_dim_group_infos`` entries for the selected
        dims and is registered in ``self.child_to_parent_mapping`` with
        ``device_mesh`` as its parent.

    Raises:
        RuntimeError: If the rank reported by ``device_mesh.get_rank()``
            does not occur anywhere in the parent's rank tensor.
    """
    # Translate each requested name into its dim index in the parent mesh.
    mesh_dims = [
        not_none(device_mesh.mesh_dim_names).index(mesh_dim_name)
        for mesh_dim_name in mesh_dim_names
    ]
    cur_rank = device_mesh.get_rank()
    mesh = device_mesh.mesh

    # Swap the selected dims to the end, then reshape to flatten the other
    # dims, so that each row of ``pg_ranks_by_dim`` is one candidate child and
    # we can pick the one containing ``cur_rank``.
    all_mesh_dims = list(range(mesh.ndim))
    for mesh_dim in mesh_dims:
        # remove() rather than pop(): we want to drop the entry whose *value*
        # is this dim index, not whatever sits at that position, because the
        # list shrinks as we go.
        all_mesh_dims.remove(mesh_dim)

    mesh_sizes = [mesh.size(mesh_dim) for mesh_dim in mesh_dims]
    pg_ranks_by_dim = mesh.permute(*all_mesh_dims, *mesh_dims).reshape(
        -1, *mesh_sizes
    )

    res_sub_mesh: Optional["DeviceMesh"] = None
    for mesh_nd in pg_ranks_by_dim:
        if cur_rank in mesh_nd:
            res_sub_mesh = DeviceMesh(
                device_mesh.device_type,
                mesh_nd,
                mesh_dim_names=mesh_dim_names,
            )

    # Fail with an explicit message instead of an UnboundLocalError when no
    # slice contains the current rank.
    if res_sub_mesh is None:
        raise RuntimeError(
            f"Cannot create a child mesh: rank {cur_rank} is not part of "
            "the parent device mesh."
        )

    res_sub_mesh._dim_group_infos = [
        device_mesh._dim_group_infos[mesh_dim] for mesh_dim in mesh_dims
    ]
    # Assign the current DeviceMesh as the parent of the child DeviceMesh.
    self.child_to_parent_mapping[res_sub_mesh] = device_mesh
    return res_sub_mesh
def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
return self.child_to_parent_mapping<span class="