torch.distributed.nn.api.remote_module 的源代码
#!/usr/bin/python3
import collections
import io
import sys
import types
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
import torch
import torch.distributed.rpc as rpc
from torch import Tensor, device, dtype, nn
from torch.distributed.nn.jit import instantiator
from torch.distributed import _remote_device
from torch.distributed.rpc.internal import _internal_rpc_pickler
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle
# Public API of this module.
__all__ = ["RemoteModule"]

# Alias for a gradient value: either a single Tensor or a tuple of Tensors.
_grad_t = Union[Tuple[Tensor, ...], Tensor]

# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self
# for why `T` is used to annotate `self`. Many `Module` methods return `self`,
# and we want those return values to carry the subclass type rather than the
# looser `Module` type.
T = TypeVar("T", bound="Module")
# Generated once, at import time, by the instantiator's non-scriptable
# remote-module template helper.
_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = instantiator.instantiate_non_scriptable_remote_module_template()
_REMOTE_MODULE_PICKLED_ATTRIBUTES = (
"on",
"device",
"is_device_map_set",
"is_scriptable",
"generated_methods",
"module_rref",
)
_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES) # type: ignore[misc]
# 这些属性大多来自 RemoteModule 的父类,并且有意不进行序列化。
# RemoteModule 的新属性应位于 _REMOTE_MODULE_PICKLED_ATTRIBUTES 或 _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING 中。
# 否则,它将不会被序列化。
_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = (
"training",
"_parameters",
"_buffers",
"_non_persistent_buffers_set",
"_backward_hooks",
"_backward_pre_hooks",
"_is_full_backward_hook",
"_forward_hooks",
"_forward_hooks_with_kwargs",
"_forward_hooks_always_called",
"_forward_pre_hooks",
"_forward_pre_hooks_with_kwargs",
"_state_dict_hooks",
"_state_dict_pre_hooks",
"_load_state_dict_pre_hooks",
"_load_state_dict_post_hooks",
"_state_dict_pre_hooks",
"_modules",
# 以下两个属性是生成的方法,在序列化时不存在。
"forward_async",
"forward",
)
# RPC 处理器。
def _instantiate_template(module_interface_cls, enable_moving_cpu_tensors_to_cuda):
    """Generate the scriptable remote-module template for an interface class.

    Thin wrapper that forwards both arguments, positionally, to
    ``instantiator.instantiate_scriptable_remote_module_template``. That
    helper's result is discarded; this function always returns ``None``.
    """
    instantiate = instantiator.instantiate_scriptable_remote_module_template
    instantiate(module_interface_cls, enable_moving_cpu_tensors_to_cuda)
def _create_module(module_cls, args, kwargs, device):
module = module_cls(*args, **kwargs)
if not isinstance(module, nn.Module):
raise ValueError(
"期望 `module_cls(*args, **kwargs)` 返回 的实例,"
f"但它返回了 {type(module)} 的实例。"
)
module.to(device)
return module
def _create_module_with_interface(
    module_cls, args, kwargs, device, module_interface_cls
):
    """Build a module via ``_create_module`` and wrap it in an ``rpc.RRef``.

    When ``module_interface_cls`` is not ``None``, the module is first
    compiled with ``torch.jit.script``. In both cases ``module_interface_cls``
    is passed to ``rpc.RRef`` as its second argument.

    Returns:
        An ``rpc.RRef`` holding either the eager module or its scripted form.
    """
    created = _create_module(module_cls, args, kwargs, device)
    if module_interface_cls is None:
        # No interface given: wrap the eager module unchanged.
        return rpc.RRef(created, module_interface_cls)
    scripted = torch.jit.script(created)
    return rpc.RRef(scripted, module_interface_cls)
def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
    """Wrap every parameter of the module held by ``module_rref`` in an RRef.

    Args:
        module_rref: RRef whose ``local_value()`` is the module to inspect.
            Per the RRef documentation, ``local_value()`` only succeeds on
            the owner worker.
        recurse: Forwarded to ``nn.Module.parameters``; when true,
            parameters of submodules are included as well.

    Returns:
        One ``rpc.RRef`` per parameter, in ``parameters()`` iteration order.
    """
    return [
        rpc.RRef(param)
        for param in module_rref.local_value().parameters(recurse)
    ]
def _raise_not_supported(name: str) -> None:
raise ValueError(f"方法 ``{name<span class="si