Shortcuts

torch.nn.modules.lazy 的源代码

```html
import itertools
import warnings
from typing import Protocol, Optional, Type, Any

import torch
from ..parameter import is_lazy

__all__ = ['LazyModuleMixin']

class _LazyProtocol(Protocol):
    """这个类用于避免在mixin中使用mypy检查属性时出现错误。

    https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
    """

    def _register_load_state_dict_pre_hook(self, hook):
        ...

    def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False):
        ...

    def _lazy_load_hook(
            self, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
        ...

    def _get_name(self):
        ...

    def _infer_parameters(self, module, input):
        ...

    @property
    def _parameters(self):
        ...

    @property
    def _buffers(self):
        ...

    @property
    def _non_persistent_buffers_set(self):
        ...

    @property
    def _load_hook(self):
        ...

    @property
    def _initialize_hook(self):
        ...


[docs]class LazyModuleMixin: r"""用于延迟初始化参数的模块的mixin,也称为“延迟模块”。 .. 警告: 延迟模块是一个处于积极开发中的实验性新功能,其API可能会发生变化。 延迟初始化参数的模块,或“延迟模块”, 从其forward方法的第一个输入中推导出参数的形状。 在第一次forward之前,它们包含 :class:`torch.nn.UninitializedParameter`,这些参数不应被访问或使用, 之后它们包含常规的 :class:`torch.nn.Parameter`。 延迟模块很方便,因为它们不需要计算某些 模块参数,例如典型的 :class:`torch.nn.Linear` 的 :attr:`in_features` 参数。 构造之后,带有延迟模块的网络应首先 转换为所需的dtype并放置在预期的设备上。 这是因为延迟模块仅执行形状推断,因此通常的dtype 和设备放置行为适用。 然后,延迟模块应执行“干运行”以初始化模块中的所有组件。 这些“干运行”将正确大小、dtype和设备的输入传递给 网络及其每个延迟模块。之后,网络可以像往常一样使用。 >>> # xdoctest: +SKIP >>> class LazyMLP(torch.nn.Module): ... def __init__(self): ... super().__init__() ... self.fc1 = torch.nn.LazyLinear(10) ... self.relu1 = torch.nn.ReLU() ... self.fc2 = torch.nn.LazyLinear(1) ... self.relu2 = torch.nn.ReLU() ... ... def forward(self, input): ... x = self.relu1(self.fc1(input)) ... y = self.relu2(self.fc2(x)) ... return y >>> # 构造一个带有延迟模块的网络 >>> lazy_mlp = LazyMLP() >>> # 转换网络的设备和dtype >>> # 注意:这些转换应在构造之后和任何“干运行”之前应用 >>> lazy_mlp = lazy_mlp.cuda().double() >>> lazy_mlp LazyMLP( (fc1): LazyLinear(in_features=0, out_features=10, bias=True) (relu1): ReLU() (fc2): LazyLinear(in_features=0, out_features=1, bias=True) (relu2): ReLU() ) >>> # 执行干运行以初始化网络的延迟模块 >>> lazy_mlp(torch.ones(10,10).cuda()) >>> # 初始化后,LazyLinear模块成为常规的Linear模块 >>> lazy_mlp LazyMLP( (fc1): Linear(in_features=10, out_features=10, bias=True) (relu1): ReLU() (fc2): Linear(in_features=10, out_features=1, bias=True) (relu2): ReLU() ) >>> # 附加一个优化器,因为参数现在可以像往常一样使用 >>> optim = torch.optim.SGD(mlp.parameters(), lr=0.01) 使用延迟模块时的一个注意事项是网络参数的初始化顺序可能会发生变化, 因为延迟模块总是在其他模块之后初始化。 例如,如果上面定义的LazyMLP类首先有一个 :class:`torch.nn.LazyLinear` 模块, 然后是一个常规的 :class:`torch.nn.Linear` 模块,第二个模块将在构造时初始化, 而第一个模块将在第一次干运行时初始化。 这可能会导致使用延迟模块的网络的参数初始化顺序与不使用延迟模块的网络不同, 因为参数初始化的顺序通常依赖于一个有状态的随机数生成器,而顺序是不同的。 有关更多详细信息,请参阅 :doc:`/notes/randomness`。 延迟模块可以像其他模块一样使用状态字典进行序列化。例如: >>> lazy_mlp = LazyMLP() >>> # 状态字典显示未初始化的参数 >>> lazy_mlp.state_dict() OrderedDict([('fc1.weight', Uninitialized parameter), ('fc1.bias', tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), ('fc2.weight', Uninitialized parameter), ('fc2.bias', tensor([0.0019]))]) 延迟模块可以加载常规的 :class:`torch.nn.Parameter`(即您可以序列化/反序列化 已初始化的LazyModules,它们将保持初始化状态) >>> full_mlp = LazyMLP() >>> # 干运行以初始化另一个模块 >>> full_mlp.forward(torch.ones(10, 1)) >>> # 将已初始化的状态加载到延迟模块中 >>> lazy_mlp.load_state_dict(full_mlp.state_dict()) >>> # 状态字典现在包含有效值 >>> lazy_mlp.state_dict() OrderedDict([('fc1.weight', tensor([[-0.3837], [ 0.0907], [ 0.6708], [-0.5223], [-0.9028], [ 0.2851], [-0.4537], [ 0.6813], [ 0.5766], [-0.8678]])), ('fc1.bias', tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), ('fc2.weight', tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.108
优云智算