
Source code for torch.distributed

import os
import sys
from enum import Enum
import pdb
import io

import torch

def is_available() -> bool:
    """
    Return ``True`` if the distributed package is available.

    Otherwise, ``torch.distributed`` does not expose any other APIs. Currently,
    ``torch.distributed`` is available on Linux, MacOS and Windows. Set
    ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source.
    Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows,
    ``USE_DISTRIBUTED=0`` for MacOS.
    """
    return hasattr(torch._C, "_c10d_init")
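# Illustrative usage (a minimal sketch, not part of the module source):
# ``is_available()`` only reports whether the distributed package was compiled
# in; it says nothing about whether a process group has been initialized.
#
#   import torch.distributed as dist
#
#   if dist.is_available():
#       print(dist.is_initialized())  # False until init_process_group() runs
#   else:
#       print("PyTorch was built without distributed support")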
if is_available() and not torch._C._c10d_init():
    raise RuntimeError("Failed to initialize torch.distributed")

# Custom runtime errors thrown from the distributed package
DistError = torch._C._DistError
DistBackendError = torch._C._DistBackendError
DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError

if is_available():
    from torch._C._distributed_c10d import (
        Store,
        FileStore,
        TCPStore,
        ProcessGroup as ProcessGroup,
        Backend as _Backend,
        PrefixStore,
        Reducer,
        Logger,
        BuiltinCommHookType,
        GradBucket,
        Work as _Work,
        _DEFAULT_FIRST_BUCKET_BYTES,
        _register_comm_hook,
        _register_builtin_comm_hook,
        _broadcast_coalesced,
        _compute_bucket_assignment_by_size,
        _verify_params_across_processes,
        _test_python_store,
        DebugLevel,
        get_debug_level,
        set_debug_level,
        set_debug_level_from_env,
        _make_nccl_premul_sum,
    )

    class _DistributedPdb(pdb.Pdb):
        """
        Supports using PDB from inside a multiprocessing child process.

        Usage:
        _DistributedPdb().set_trace()
        """

        def interaction(self, *args, **kwargs):
            _stdin = sys.stdin
            try:
                sys.stdin = open('/dev/stdin')
                pdb.Pdb.interaction(self, *args, **kwargs)
            finally:
                sys.stdin = _stdin
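    # Illustrative usage (a minimal sketch, not part of the module source):
    # the ``Dist*Error`` classes above derive from RuntimeError, so callers
    # can single out distributed failures without breaking existing handlers.
    #
    #   import torch.distributed as dist
    #
    #   try:
    #       dist.init_process_group(backend="nccl")
    #   except dist.DistBackendError:
    #       ...  # backend-specific failure (e.g. an NCCL error)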
    def breakpoint(rank: int = 0):
        """
        Set a breakpoint, but only on a single rank. All other ranks will
        wait for you to be done with the breakpoint before continuing.

        Args:
            rank (int): Which rank to break on. Default: ``0``
        """
        if get_rank() == rank:
            pdb = _DistributedPdb()
            pdb.message(
                "\n!!! ATTENTION !!!\n\n"
                f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
            )
            pdb.set_trace()
        barrier()
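    # Illustrative usage (a minimal sketch, not part of the module source):
    # every rank must reach the call, since the non-breaking ranks block on
    # the trailing barrier() until the debugged rank continues.
    #
    #   import torch.distributed as dist
    #
    #   def train_step(model, batch):
    #       dist.breakpoint(rank=0)  # rank 0 drops into pdb, others wait
    #       ...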
    if sys.platform != "win32":
        from torch._C._distributed_c10d import (
            HashStore,
            _round_robin_process_groups,
        )

    from .distributed_c10d import *  # noqa: F403

    # Variables prefixed with underscore are not auto imported
    # See the comment in `distributed_c10d.py` above `_backend` on why we
    # expose this.
    from .distributed_c10d import (
        _all_gather_base,
        _reduce_scatter_base,
        _create_process_group_wrapper,
        _rank_not_in_group,
        _coalescing_manager,
        _CoalescingManager,
        _get_process_group_name,
    )

    from .rendezvous import (
        rendezvous,
        _create_store_from_options,
        register_rendezvous_handler,
    )

    from .remote_device import _remote_device

    set_debug_level_from_env()

else:
    # This stub is enough to get
    #   python test/test_public_bindings.py -k test_correct_module_names
    # working even when USE_DISTRIBUTED=0. Feel free to add more
    # stubs as necessary.
    # We cannot define stubs directly because they confuse pyre
    class _ProcessGroupStub:
        pass

    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]
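# Illustrative usage (a minimal sketch, not part of the module source): the
# stub keeps ``torch.distributed.ProcessGroup`` referenceable in annotations
# even on builds where the distributed package is unavailable.
#
#   from typing import Optional
#   import torch.distributed as dist
#
#   def maybe_barrier(pg: Optional[dist.ProcessGroup] = None) -> None:
#       if dist.is_available() and dist.is_initialized():
#           dist.barrier(group=pg)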