Shortcuts

torch.distributed.rpc.api 的源代码

__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync",
           "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"]

import collections
import contextlib
import functools
import inspect
import logging
import threading
from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING

import torch
from torch.futures import Future

from torch._C._distributed_rpc import (
    PyRRef,
    RemoteProfilerManager,
    WorkerInfo,
    TensorPipeAgent,
    get_rpc_timeout,
    _cleanup_python_rpc_handler,
    _delete_all_user_and_unforked_owner_rrefs,
    _destroy_rref_context,
    _get_current_rpc_agent,
    _invoke_remote_builtin,
    _invoke_remote_python_udf,
    _invoke_remote_torchscript,
    _invoke_rpc_builtin,
    _invoke_rpc_python_udf,
    _invoke_rpc_torchscript,
    _is_current_rpc_agent_set,
    _reset_current_rpc_agent,
    _set_and_start_rpc_agent,
)

from .internal import (
    PythonUDF,
    RPCExecMode,
    _internal_rpc_pickler,
    _build_rpc_profiling_key,
)

from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT

from ._utils import _group_membership_management, _update_group_membership

logger = logging.getLogger(__name__)

# NB: 在关闭期间忽略RRef泄漏。如果没有这个,应用程序必须
# 确保应用程序代码中没有任何对任何RRef的引用,并且
# Python GC已经完成了删除这些RRef的工作。这可能会导致
# 特别是在大型应用程序中的调试体验不佳。因此,默认情况下,
# 我们将在关闭期间忽略RRef泄漏。这通常是
# 可以接受的,因为关闭意味着应用程序已经完成了训练,不再关心
# 状态。
#
# 要启用RRef泄漏检查,请将此_ignore_rref_leak设置为False
_ignore_rref_leak = True
_default_pickler = _internal_rpc_pickler

@contextlib.contextmanager
def _use_rpc_pickler(rpc_pickler):
    r"""
    rpc_pickler: (.internal._InternalRPCPickler) 覆盖默认的RPC pickler
    """
    global _default_pickler
    _default_pickler = rpc_pickler
    try:
        yield
    finally:
        _default_pickler = _internal_rpc_pickler


def _require_initialized(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not _is_current_rpc_agent_set():
            raise RuntimeError(
                "RPC尚未初始化。请先调用"
                "torch.distributed.rpc.init_rpc。"
            )
        return func(*args, **kwargs)

    return wrapper


class AllGatherStates:
    def __init__(self):
        # 每个`gathered_objects`最初是一个空字典。
        # 领导者工作线程被选为排序后的工作线程名称列表中的第一个工作线程。
        # 每当有工作线程进入`_all_gather()`时,它
        # 在领导者上运行`_gather_to_leader()`,将自己的名称和
        # 数据对象添加到此字典中。领导者也会在调用`_all_gather()`时
        # 将自己的名称添加到字典中。
        # 一旦`set(gathered_objects.keys()) == _ALL_WORKER_NAMES`,领导者
        # 将广播收集到的字典给所有跟随者工作线程,并设置他们的
        # `gathered_objects`字段和`proceed_signal`字段。
        self.gathered_objects = {}
        # 所有工作线程都在此信号上等待,直到接收到所有收集到的
        # 对象。
        self.proceed_signal = threading.Event()


# `def _all_gather()`使用的状态。
# `_ALL_WORKER_NAMES`在初始化RPC层时初始化。
_ALL_WORKER_NAMES: Set[Any] = set()
_all_gather_dict_lock = threading.RLock()
_all_gather_sequence_id: Dict[str, int] = {}
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)


def _init_rpc_states(agent):
    worker_infos = agent.get_worker_infos()
    global _ALL_WORKER_NAMES
    _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}

    # NB: 后端实现可能已经设置了rpc_agent。
    if not _is_current_rpc_agent_set():
        _set_and_start_rpc_agent(agent)


def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
    with _all_gather_dict_lock:
        if not worker_names:
            worker_names = _ALL_WORKER_NAMES
            assert (
                worker_name in worker_names
            ), f"{worker_name} 不是领导者期望的。"
        states