Shortcuts

torch.distributed.pipeline.sync.skip.skippable 的源代码

# 版权 2019 Kakao Brain
#
# 版权所有 (c) Facebook, Inc. 及其附属公司。保留所有权利。
#
# 本源代码根据在
# 本源代码树的根目录中的LICENSE文件中找到的BSD许可证授权。
"""定义跳过连接的用户界面。"""
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Dict,
    FrozenSet,
    Generator,
    Iterable,
    List,
    Optional,
    Set,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

from torch import Tensor, nn

from ..microbatch import Batch
from .namespace import Namespace
from .tracker import current_skip_tracker

__all__ = ["skippable", "stash", "pop", "verify_skippables"]


Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

StashPop = Union["stash", "pop"]
StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
if TYPE_CHECKING:
    # 类型检查:nn.Module不是泛型
    SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]  # type: ignore[type-arg]
else:
    SkippableModule = nn.Module

T = TypeVar("T", bound="Skippable")


class Skippable(nn.Module):
    """可跳过模块的基类。

    不要直接使用此类。请通过 :func:`skippable` 定义子类。

    """

    module_cls: ClassVar[Type[SkippableModule]]
    stashable_names: ClassVar[FrozenSet[str]]
    poppable_names: ClassVar[FrozenSet[str]]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        self.module = self.module_cls(*args, **kwargs)  # type: ignore[call-arg]
        self.namespaces: Dict[str, Namespace] = {}

    def __repr__(self) -> str:
        return f"@skippable({self.module})"

    def namespaced(self, name: str) -> Tuple[Namespace, str]:
        """为给定的跳过名称添加命名空间前缀。"""
        ns = self.namespaces.get(name)
        ns = cast(Namespace, ns)
        return (ns, name)

    def stashable(self) -> Iterable[Tuple[Namespace, str]]:
        """迭代要存储的命名空间跳过名称。"""
        for name in self.stashable_names:
            yield self.namespaced(name)

    def poppable(self) -> Iterable[Tuple[Namespace, str]]:
        """迭代要弹出的命名空间跳过名称。"""
        for name in self.poppable_names:
            yield self.namespaced(name)

    def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T:
        """隔离指定的子集或整个跳过张量集。

        在单个顺序模块中,除非通过不同的命名空间隔离,否则不允许具有相同名称的跳过张量。

        以下是使用相同名称两次的示例。每对 ``Layer1`` 和 ``Layer2`` 都通过其自己的命名空间 ``ns1`` 和 ``ns2`` 隔离。不再有冲突::

            ns1 = Namespace()
            ns2 = Namespace()

            model = nn.Sequential(
                Layer1().isolate(ns1),
                Layer1().isolate(ns2),
                Layer2(),
                Layer3().isolate(ns2),
                Layer3().isolate(ns1),
            )

        当 `only` 参数被省略时,所有跳过张量都被隔离。您可以通过传递 `only` 参数来隔离跳过张量的子集::

            ns_alice = Namespace()
            ns_bob = Namespace()

            model = nn.Sequential(
                ...
                StashStashPop().isolate(ns_alice, only=['alice'])