Shortcuts

torch.nn.utils.skip_init

torch.nn.utils.skip_init(module_cls, *args, **kwargs)[源代码]

给定一个模块类对象和 args / kwargs,实例化模块而不初始化参数 / 缓冲区。

如果初始化过程较慢或需要执行自定义初始化,从而使默认初始化变得不必要时,这可能会很有用。由于此函数的实现方式,存在一些注意事项:

1. 该模块必须在构造函数中接受一个设备参数,该参数在构造期间传递给任何创建的参数或缓冲区。

2. 模块在其构造函数中不得对参数执行任何计算,除了初始化(即来自 torch.nn.init 的函数)。

如果这些条件得到满足,模块可以通过未初始化的参数/缓冲区值进行实例化,就像使用torch.empty()创建的一样。

Parameters
  • module_cls – 类对象;应为 torch.nn.Module 的子类

  • args – 传递给模块构造函数的参数

  • kwargs – 传递给模块构造函数的参数

Returns

已实例化的模块,参数/缓冲区未初始化

示例:

>>> import torch
>>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
>>> m.weight
包含的参数:
tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
       requires_grad=True)
>>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
>>> m2.weight
包含的参数:
tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
          4.5915e-41]], requires_grad=True)