torch.nn.utils.init 的源代码
import inspect
import torch
[docs]def skip_init(module_cls, *args, **kwargs):
r"""
给定一个模块类对象和参数/关键字参数,实例化模块而不初始化参数/缓冲区。
如果初始化过程较慢或将执行自定义初始化,使得默认初始化变得不必要,这可能会很有用。由于此函数的实现方式,存在一些注意事项:
1. 模块的构造函数必须接受一个`device`参数,该参数传递给在构造过程中创建的任何参数或缓冲区。
2. 模块的构造函数不得对参数执行任何计算,除了初始化(即来自 :mod:`torch.nn.init` 的函数)。
如果满足这些条件,模块可以实例化,参数/缓冲区的值未初始化,就像使用 :func:`torch.empty` 创建的一样。
参数:
module_cls: 类对象;应该是 :class:`torch.nn.Module` 的子类
args: 传递给模块构造函数的参数
kwargs: 传递给模块构造函数的关键字参数
返回:
实例化的模块,参数/缓冲区未初始化
示例::
>>> # xdoctest: +IGNORE_WANT("非确定性")
>>> import torch
>>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
>>> m.weight
包含的参数:
tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
requires_grad=True)
>>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
>>> m2.weight
包含的参数:
tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24,
4.5915e-41]], requires_grad=True)
"""
if not issubclass(module_cls, torch.nn.Module):
raise RuntimeError(f'Expected a Module; got {module_cls}')
if 'device' not in inspect.signature(module_cls).parameters:
raise RuntimeError('Module must support a \'device\' arg to skip initialization')
final_device = kwargs.pop('device', 'cpu')
kwargs['device'] = 'meta'
return module_cls(*args, **kwargs).to_empty(device=final_device)