# Source code for torch.distributed.pipeline.sync.pipe
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Pipe 接口。"""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Union, Sequence, Tuple, cast
import torch
from torch import Tensor, nn
from torch.distributed.rpc import RRef
import torch.autograd
import torch.cuda
from . import microbatch
from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline
from .skip.layout import inspect_skip_layout
from .skip.skippable import verify_skippables
from .stream import AbstractStream, new_stream
__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"]
# A device may be designated by a torch.device, a device index, or a string (e.g. "cuda:0").
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]
# Tensors: a sequence of tensors; TensorOrTensors: one tensor or such a sequence.
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]
if TYPE_CHECKING:
    # For the static type checker only: nn.Module is not generic at runtime,
    # so the subscripted form must be hidden behind TYPE_CHECKING.
    Module = nn.Module[TensorOrTensors]  # type: ignore[type-arg]
    NamedModules = OrderedDict[str, Module]
else:
    Module = nn.Module
    NamedModules = OrderedDict
def _recommend_auto_balance(message: str) -> str:
"""通过推荐使用 :mod:`torchpipe.balance` 扩展消息。"""
return f"""{message}
如果您的模型仍在开发中,其最佳平衡会频繁变化。在这种情况下,我们强烈推荐使用 'torch.distributed.pipeline.sync.balance' 进行
简单的自动平衡:
from torch.distributed.pipeline.sync import Pipe
from torch.distributed.pipeline.sync.balance import balance_by_time
partitions = torch.cuda.device_count()
sample = torch.empty(...)
balance = balance_by_time(partitions, model, sample)
model = Pipe(model, balance, ...)
"""
def _verify_module(module: nn.Sequential) -> None:
if not isinstance(module, nn.Sequential):
raise TypeError("module 必须是 nn.Sequential 才能进行分区")
named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError("不支持具有重复子模块的模块")
def _verify_splitting(
module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device]
) -> None:
num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
if num_parameters == num_child_parameters:
return
for i in range(len(partitions)):
for j in range(i + 1, len(partitions)):
parti = partitions[i]
partj = partitions[j]
if devices[i] == devices[j]:
continue
for p in parti.parameters():
for q in partj.parameters():
if p is q:
raise ValueError("不支持在不同设备上具有重复参数的模块")
class BalanceError(ValueError):
    # ValueError subclass for pipeline-balance problems.
    # NOTE(review): raised by balance/partition validation — confirm exact
    # raise sites against the rest of the module.
    pass
def _retrieve_device(module: nn.Module) -> torch.device:
"""验证模块中的所有参数具有相同的设备并返回适当的设备。
参数:
要处理的 ``nn.Module``。
返回:
整个模块的 ``torch.Device``。
引发:
ValueError:
如果 ``nn.Module`` 参数的设备不相同。
"""
device</span