torch.distributed.pipeline.sync.skip.skippable 的源代码
# 版权 2019 Kakao Brain
#
# 版权所有 (c) Facebook, Inc. 及其附属公司。保留所有权利。
#
# 本源代码根据在
# 本源代码树的根目录中的LICENSE文件中找到的BSD许可证授权。
"""定义跳过连接的用户界面。"""
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
FrozenSet,
Generator,
Iterable,
List,
Optional,
Set,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from torch import Tensor, nn
from ..microbatch import Batch
from .namespace import Namespace
from .tracker import current_skip_tracker
__all__ = ["skippable", "stash", "pop", "verify_skippables"]
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]
StashPop = Union["stash", "pop"]
StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
if TYPE_CHECKING:
# 类型检查:nn.Module不是泛型
SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]] # type: ignore[type-arg]
else:
SkippableModule = nn.Module
T = TypeVar("T", bound="Skippable")
class Skippable(nn.Module):
"""可跳过模块的基类。
不要直接使用此类。请通过 :func:`skippable` 定义子类。
"""
module_cls: ClassVar[Type[SkippableModule]]
stashable_names: ClassVar[FrozenSet[str]]
poppable_names: ClassVar[FrozenSet[str]]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
self.module = self.module_cls(*args, **kwargs) # type: ignore[call-arg]
self.namespaces: Dict[str, Namespace] = {}
def __repr__(self) -> str:
return f"@skippable({self.module})"
def namespaced(self, name: str) -> Tuple[Namespace, str]:
"""为给定的跳过名称添加命名空间前缀。"""
ns = self.namespaces.get(name)
ns = cast(Namespace, ns)
return (ns, name)
def stashable(self) -> Iterable[Tuple[Namespace, str]]:
"""迭代要存储的命名空间跳过名称。"""
for name in self.stashable_names:
yield self.namespaced(name)
def poppable(self) -> Iterable[Tuple[Namespace, str]]:
"""迭代要弹出的命名空间跳过名称。"""
for name in self.poppable_names:
yield self.namespaced(name)
def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T:
"""隔离指定的子集或整个跳过张量集。
在单个顺序模块中,除非通过不同的命名空间隔离,否则不允许具有相同名称的跳过张量。
以下是使用相同名称两次的示例。每对 ``Layer1`` 和 ``Layer2`` 都通过其自己的命名空间 ``ns1`` 和 ``ns2`` 隔离。不再有冲突::
ns1 = Namespace()
ns2 = Namespace()
model = nn.Sequential(
Layer1().isolate(ns1),
Layer1().isolate(ns2),
Layer2(),
Layer3().isolate(ns2),
Layer3().isolate(ns1),
)
当 `only` 参数被省略时,所有跳过张量都被隔离。您可以通过传递 `only` 参数来隔离跳过张量的子集::
ns_alice = Namespace()
ns_bob = Namespace()
model = nn.Sequential(
...
StashStashPop().isolate(ns_alice, only=['alice'])