Source code for torch.distributed.pipeline.sync.pipe

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""The Pipe interface."""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Union, Sequence, Tuple, cast

import torch
from torch import Tensor, nn
from torch.distributed.rpc import RRef
import torch.autograd
import torch.cuda

from . import microbatch
from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline
from .skip.layout import inspect_skip_layout
from .skip.skippable import verify_skippables
from .stream import AbstractStream, new_stream

__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"]


Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]

Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

if TYPE_CHECKING:
    # Typechecking: nn.Module is not a Generic
    Module = nn.Module[TensorOrTensors]  # type: ignore[type-arg]
    NamedModules = OrderedDict[str, Module]
else:
    Module = nn.Module
    NamedModules = OrderedDict


def _recommend_auto_balance(message: str) -> str:
    """通过推荐使用 :mod:`torchpipe.balance` 扩展消息。"""
    return f"""{message}

If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance'
for naive automatic balancing:

  from torch.distributed.pipeline.sync import Pipe
  from torch.distributed.pipeline.sync.balance import balance_by_time

  partitions = torch.cuda.device_count()
  sample = torch.empty(...)
  balance = balance_by_time(partitions, model, sample)

  model = Pipe(model, balance, ...)
"""


def _verify_module(module: nn.Sequential) -> None:
    if not isinstance(module, nn.Sequential):
        raise TypeError("module must be nn.Sequential to be partitioned")

    named_children = list(module.named_children())
    if len(named_children) != len(module):
        raise ValueError("module with duplicate children is not supported")
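

# Illustrative sketch (not part of the original source; the helper name is
# hypothetical): ``named_children()`` de-duplicates repeated module objects,
# so registering the same child twice makes it shorter than ``len(module)``
# and trips the check above.
def _example_duplicate_children() -> None:
    shared = nn.Linear(8, 8)
    seq = nn.Sequential(shared, nn.ReLU())
    seq.add_module("dup", shared)  # same module object registered twice
    _verify_module(seq)  # raises ValueError("module with duplicate children ...")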


def _verify_splitting(
    module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device]
) -> None:
    num_parameters = len(list(module.parameters()))
    num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
    if num_parameters == num_child_parameters:
        return

    for i in range(len(partitions)):
        for j in range(i + 1, len(partitions)):
            parti = partitions[i]
            partj = partitions[j]
            if devices[i] == devices[j]:
                continue
            for p in parti.parameters():
                for q in partj.parameters():
                    if p is q:
                        raise ValueError("module with duplicate parameters on distinct devices is not supported")
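

# Illustrative sketch (not part of the original source; the helper name is
# hypothetical): two partitions placed on distinct devices that share a
# parameter tensor trip the check above. Tying ``b.weight`` to ``a.weight``
# also makes the de-duplicated module parameter count differ from the
# per-child sum, so the early return is skipped.
def _example_duplicate_parameters() -> None:
    a = nn.Linear(4, 4)
    b = nn.Linear(4, 4)
    b.weight = a.weight  # tie the weight parameter across both layers
    model = nn.Sequential(a, b)
    partitions = [nn.Sequential(a), nn.Sequential(b)]
    devices = [torch.device("cpu"), torch.device("cuda:0")]
    _verify_splitting(model, partitions, devices)  # raises ValueError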


class BalanceError(ValueError):
    pass


def _retrieve_device(module: nn.Module) -> torch.device:
    """验证模块中的所有参数具有相同的设备并返回适当的设备。

    参数:
        要处理的 ``nn.Module``。

    返回:
        整个模块的 ``torch.Device``。

    引发:
        ValueError:
            如果 ``nn.Module`` 参数的设备不相同。
    """

    device = None
    for parameter in module.parameters():
        if device is None:
            device = parameter.device
        elif device != parameter.device:
            raise ValueError(
                f'nn.Module: {module}, should have all parameters on a single device,'
                ' please use .to() to place the module on a single device')

    return device if device is not None else torch.device("cpu")
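

# Illustrative sketch (not part of the original source; the helper name is
# hypothetical): ``_retrieve_device`` returns the common device of a module's
# parameters and raises once parameters are spread across devices.
def _example_retrieve_device() -> None:
    block = nn.Linear(4, 4)  # all parameters start on the CPU
    assert _retrieve_device(block) == torch.device("cpu")
    mixed = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4).to("cuda:0"))
    _retrieve_device(mixed)  # raises ValueError (needs a CUDA device to run)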