torch.distributed.algorithms.join 的源代码
```html
import warnings from abc import ABC, abstractmethod from types import TracebackType from typing import Any, List, NamedTuple, Optional, Type import torch import torch.distributed as dist __all__ = ['JoinHook', 'Joinable', 'Join'][docs]class JoinHook: r""" 这定义了一个join钩子,它在join上下文管理器中提供了两个入口点。 入口点:一个主钩子,在存在未加入的进程时重复调用,以及一个后钩子,在所有进程加入后调用一次。 要为通用join上下文管理器实现join钩子,定义一个继承自:class:`JoinHook`的类,并根据需要重写``main_hook()``和``post_hook()``。 """[docs] def main_hook(self) -> None: r"""在存在未加入的进程时调用此钩子以在训练迭代中掩盖集体通信。 训练迭代即在一个前向传递、后向传递和优化器步骤中。 """ ...[docs] def post_hook(self, is_last_joiner: bool) -> None: r""" 在所有进程加入后调用钩子。 它传递一个额外的``bool``参数``is_last_joiner``,指示该rank是否是最后加入的之一。 参数: is_last_joiner (bool): 如果该rank是最后加入的之一,则为``True``;否则为``False``。 """ ...[docs]class Joinable(ABC): r""" 这定义了一个可加入类的抽象基类。 一个可加入的类 (继承自:class:`Joinable`)应实现:meth:`join_hook`, 返回一个:class:`JoinHook`实例,此外还应实现 :meth:`join_device`和:meth:`join_process_group`,分别返回设备和 进程组信息。 """ @abstractmethod def __init__(self): super().__init__() self._join_config = _JoinConfig.construct_disabled_join_config()class _JoinConfig(NamedTuple): r"""这包括join上下文管理器端所需的:class:`Joinable`实例的所有字段。""" enable: bool throw_on_early_termination: bool is_first_joinable: bool @staticmethod def construct_disabled_join_config(): r"""返回一个:class:`_JoinConfig`实例,指示应禁用与join相关的逻辑。 例如,如果调用者不在join上下文管理器中。 """ return _JoinConfig( enable=False, throw_on_early_termination=False, is_first_joinable=False )[docs] @abstractmethod def join_hook(self, **kwargs) -> JoinHook: r""" 返回给定:class:`Joinable`的:class:`JoinHook`实例。 参数: kwargs (dict): 包含任何关键字参数的字典 在运行时修改join钩子的行为;所有 共享相同join上下文管理器的:class:`Joinable`实例 都会转发相同的``kwargs``值。 """ ...@property @abstractmethod def join_device(self) -> torch.device: r"""返回执行join上下文管理器所需的集体通信的设备。""" ... @property @abstractmethod def join_process_group(self) -> Any: r"""返回join上下文管理器本身所需的集体通信的进程组。""" ...