Source code for torch.distributed
import os
import sys
from enum import Enum
import pdb
import io
import torch
def is_available() -> bool:
    """
    Return ``True`` if the distributed package is available.

    Otherwise, ``torch.distributed`` does not expose any other APIs. Currently,
    ``torch.distributed`` is available on Linux, MacOS and Windows. Set
    ``USE_DISTRIBUTED=1`` when building PyTorch from source to enable it.
    Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows,
    and ``USE_DISTRIBUTED=0`` for MacOS.
    """
    return hasattr(torch._C, "_c10d_init")
if is_available() and not torch._C._c10d_init():
    raise RuntimeError("Failed to initialize torch.distributed")
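# Illustrative sketch (not part of this module): typical caller-side guard
# before touching distributed APIs; ``is_initialized`` and ``all_reduce``
# come from the public torch.distributed surface once the package loads.
#
#     import torch
#     import torch.distributed as dist
#
#     if dist.is_available() and dist.is_initialized():
#         t = torch.ones(1)
#         dist.all_reduce(t)  # in-place sum across all ranks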
# Custom runtime errors thrown from the distributed package
DistError = torch._C._DistError
DistBackendError = torch._C._DistBackendError
DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError
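# Illustrative sketch (not part of this module): since these are runtime
# errors, they can be caught around setup or collective calls, e.g.:
#
#     import torch.distributed as dist
#     try:
#         dist.init_process_group("gloo")
#     except dist.DistBackendError:
#         ...  # backend failed to initialize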
if is_available():
    from torch._C._distributed_c10d import (
        Store,
        FileStore,
        TCPStore,
        ProcessGroup as ProcessGroup,
        Backend as _Backend,
        PrefixStore,
        Reducer,
        Logger,
        BuiltinCommHookType,
        GradBucket,
        Work as _Work,
        _DEFAULT_FIRST_BUCKET_BYTES,
        _register_comm_hook,
        _register_builtin_comm_hook,
        _broadcast_coalesced,
        _compute_bucket_assignment_by_size,
        _verify_params_across_processes,
        _test_python_store,
        DebugLevel,
        get_debug_level,
        set_debug_level,
        set_debug_level_from_env,
        _make_nccl_premul_sum,
    )
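    # Illustrative sketch (not part of this module): the Store types imported
    # above implement key-value rendezvous; e.g. a two-process TCPStore
    # (address, port, and timeout below are placeholder values):
    #
    #     from datetime import timedelta
    #     import torch.distributed as dist
    #
    #     # rank 0 hosts the store (is_master=True); other ranks connect to it
    #     store = dist.TCPStore("127.0.0.1", 29500, world_size=2, is_master=True,
    #                           timeout=timedelta(seconds=30))
    #     store.set("key", "value")
    #     store.get("key")  # b'value'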
    class _DistributedPdb(pdb.Pdb):
        """
        Supports using PDB from inside a multiprocessing child process.

        Usage:
        _DistributedPdb().set_trace()
        """
        def interaction(self, *args, **kwargs):
            _stdin = sys.stdin
            try:
                # Rebind stdin to the controlling terminal so pdb can read
                # user input from a child process, then restore it.
                sys.stdin = open('/dev/stdin')
                pdb.Pdb.interaction(self, *args, **kwargs)
            finally:
                sys.stdin = _stdin
    def breakpoint(rank: int = 0):
        """
        Set a breakpoint, but only on a single rank. All other ranks will
        continue once you are done with the breakpoint.

        Args:
            rank (int): Which rank to break on. Default: ``0``
        """
        if get_rank() == rank:
            pdb = _DistributedPdb()
            pdb.message(
                "\n!!! ATTENTION !!!\n\n"
                f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
            )
            pdb.set_trace()
        barrier()
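    # Illustrative sketch (not part of this module): with a process group
    # already initialized, drop into pdb on rank 0 while the other ranks
    # wait in barrier() until it exits:
    #
    #     import torch.distributed as dist
    #     # ... after dist.init_process_group(...) ...
    #     dist.breakpoint(rank=0)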
    if sys.platform != "win32":
        from torch._C._distributed_c10d import (
            HashStore,
            _round_robin_process_groups,
        )
    from .distributed_c10d import *  # noqa: F403

    # Variables prefixed with underscore are not auto imported
    # See the comment in `distributed_c10d.py` above `_backend` on why we
    # expose this.
    from .distributed_c10d import (
        _all_gather_base,
        _reduce_scatter_base,
        _create_process_group_wrapper,
        _rank_not_in_group,
        _coalescing_manager,
        _CoalescingManager,
        _get_process_group_name,
    )
    from .rendezvous import (
        rendezvous,
        _create_store_from_options,
        register_rendezvous_handler,
    )
    from .remote_device import _remote_device
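    # Illustrative sketch (not part of this module): ``rendezvous`` yields
    # (store, rank, world_size) tuples for a URL-described rendezvous
    # (the address and port are placeholders):
    #
    #     import torch.distributed as dist
    #     store, rank, world_size = next(
    #         dist.rendezvous("tcp://127.0.0.1:29500", rank=0, world_size=1)
    #     )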
    set_debug_level_from_env()
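    # set_debug_level_from_env() reads the TORCH_DISTRIBUTED_DEBUG environment
    # variable (OFF, INFO, or DETAIL). For example, to get extra logging and
    # consistency checks on collectives:
    #
    #     TORCH_DISTRIBUTED_DEBUG=DETAIL python train.py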
else:
    # This stub is sufficient to get
    #   python test/test_public_bindings.py -k test_correct_module_names
    # working even when USE_DISTRIBUTED=0. Feel free to add more
    # stubs as necessary.
    # We cannot define stubs directly because they confuse pyre
    class _ProcessGroupStub:
        pass

    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]