Shortcuts

torch.distributed.nn.api.remote_module 的源代码

#!/usr/bin/python3
import collections
import io
import sys
import types
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import torch
import torch.distributed.rpc as rpc
from torch import Tensor, device, dtype, nn
from torch.distributed.nn.jit import instantiator
from torch.distributed import _remote_device
from torch.distributed.rpc.internal import _internal_rpc_pickler
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle

# Public API of this module: only `RemoteModule` is exported by `import *`.
__all__ = ["RemoteModule"]

# Alias for "a single Tensor or a tuple of Tensors" (the name suggests it is
# used for gradient arguments — confirm against its users).
_grad_t = Union[Tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use of `T` to annotate `self`.
# Many methods of `Module` return `self`, and we want those return values to have the subclass type rather than the looser `Module` type.
T = TypeVar("T", bound="Module")

# Evaluated once at import time: the result of instantiating the
# non-scriptable RemoteModule template through the project's `instantiator`
# helper (presumably a generated Python module — verify in instantiator).
_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = (
    instantiator.instantiate_non_scriptable_remote_module_template()
)

# Names of the RemoteModule attributes that ARE serialized when it is pickled.
# The order of this tuple also defines the field order of
# `_SerializedRemoteModule` below, so do not reorder entries casually.
_REMOTE_MODULE_PICKLED_ATTRIBUTES = (
    "on",
    "device",
    "is_device_map_set",
    "is_scriptable",
    "generated_methods",
    "module_rref",
)

# Named tuple whose fields are exactly the pickled attributes listed above.
_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES)  # type: ignore[misc]

# Attributes that are deliberately NOT serialized when a RemoteModule is
# pickled. Most of them come from RemoteModule's parent class.
# Every new RemoteModule attribute must be listed either in
# _REMOTE_MODULE_PICKLED_ATTRIBUTES or in
# _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING; otherwise it will not be
# pickled.
# Each name appears exactly once: the tuple is used as a membership allow-list,
# so duplicates only obscure which attributes are actually covered.
_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = (
    "training",
    "_parameters",
    "_buffers",
    "_non_persistent_buffers_set",
    "_backward_hooks",
    "_backward_pre_hooks",
    "_is_full_backward_hook",
    "_forward_hooks",
    "_forward_hooks_with_kwargs",
    "_forward_hooks_always_called",
    "_forward_pre_hooks",
    "_forward_pre_hooks_with_kwargs",
    "_state_dict_hooks",
    "_state_dict_pre_hooks",
    "_load_state_dict_pre_hooks",
    "_load_state_dict_post_hooks",
    "_modules",
    # The two attributes below are generated methods; they do not exist at
    # pickling time.
    "forward_async",
    "forward",
)


# RPC 处理器。
def _instantiate_template(module_interface_cls, enable_moving_cpu_tensors_to_cuda):
    """Instantiate the scriptable RemoteModule template for an interface.

    Delegates to ``instantiator.instantiate_scriptable_remote_module_template``
    and discards its result; this function itself returns ``None``.

    Args:
        module_interface_cls: The module interface class to instantiate the
            template for.
        enable_moving_cpu_tensors_to_cuda: Flag forwarded unchanged to the
            instantiator.
    """
    instantiate = instantiator.instantiate_scriptable_remote_module_template
    instantiate(module_interface_cls, enable_moving_cpu_tensors_to_cuda)


def _create_module(module_cls, args, kwargs, device):
    """Construct ``module_cls(*args, **kwargs)`` and move it to ``device``.

    Args:
        module_cls: Callable used to build the module; its result must be an
            ``nn.Module`` instance.
        args: Positional arguments forwarded to ``module_cls``.
        kwargs: Keyword arguments forwarded to ``module_cls``.
        device: Target passed to ``module.to(...)``.

    Returns:
        The constructed module, after ``module.to(device)`` has been called.

    Raises:
        ValueError: If ``module_cls(*args, **kwargs)`` does not return an
            ``nn.Module`` instance.
    """
    module = module_cls(*args, **kwargs)
    if not isinstance(module, nn.Module):
        # Name the expected type explicitly so the error message is actionable.
        raise ValueError(
            "期望 `module_cls(*args, **kwargs)` 返回 <class nn.Module> 的实例,"
            f"但它返回了 {type(module)} 的实例。"
        )
    module.to(device)
    return module


def _create_module_with_interface(
    module_cls, args, kwargs, device, module_interface_cls
):
    """Build a module via ``_create_module`` and wrap it in an ``rpc.RRef``.

    If ``module_interface_cls`` is not ``None``, the module is compiled with
    ``torch.jit.script`` before wrapping. ``module_interface_cls`` is always
    passed as the second argument of ``rpc.RRef``, including when it is ``None``.
    """
    built = _create_module(module_cls, args, kwargs, device)
    wrapped = built if module_interface_cls is None else torch.jit.script(built)
    return rpc.RRef(wrapped, module_interface_cls)


def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
    """Return one ``rpc.RRef`` per parameter of the locally-held module.

    Args:
        module_rref: RRef whose ``local_value()`` is the module to inspect.
        recurse: Forwarded to ``parameters(recurse)``.

    Returns:
        A list of ``rpc.RRef`` objects wrapping each parameter, in the order
        that ``parameters()`` yields them.
    """
    local_module = module_rref.local_value()
    return [rpc.RRef(p) for p in local_module.parameters(recurse)]


def _raise_not_supported(name: str) -> None:
    raise ValueError(f"方法 ``{name<span class="si