LazyModuleMixin¶
- class torch.nn.modules.lazy.LazyModuleMixin(*args, **kwargs)[源代码]¶
用于延迟初始化参数的模块的混合器,也称为“延迟模块”。
延迟初始化参数的模块,或称为“延迟模块”, 从其前向方法的第一个输入中推导出参数的形状。 在第一个前向传播之前,它们包含不应访问或使用的
torch.nn.UninitializedParameter
, 之后它们包含常规的torch.nn.Parameter
。 延迟模块很方便,因为它们不需要计算某些模块参数, 例如典型torch.nn.Linear
的in_features
参数。构建后,带有惰性模块的网络应首先转换为所需的dtype并放置在预期的设备上。这是因为惰性模块仅执行形状推断,因此通常的dtype和设备放置行为适用。然后,惰性模块应执行“空运行”以初始化模块中的所有组件。这些“空运行”通过网络发送具有正确大小、dtype和设备的输入,并传递给每个惰性模块。在此之后,网络可以像往常一样使用。
>>> class LazyMLP(torch.nn.Module): ... def __init__(self): ... super().__init__() ... self.fc1 = torch.nn.LazyLinear(10) ... self.relu1 = torch.nn.ReLU() ... self.fc2 = torch.nn.LazyLinear(1) ... self.relu2 = torch.nn.ReLU() ... ... def forward(self, input): ... x = self.relu1(self.fc1(input)) ... y = self.relu2(self.fc2(x)) ... return y >>> # 使用惰性模块构建网络 >>> lazy_mlp = LazyMLP() >>> # 转换网络的设备和数据类型 >>> # 注意:这些转换可以在构建后和任何“试运行”之前应用 >>> lazy_mlp = lazy_mlp.cuda().double() >>> lazy_mlp LazyMLP( (fc1): LazyLinear(in_features=0, out_features=10, bias=True) (relu1): ReLU() (fc2): LazyLinear(in_features=0, out_features=1, bias=True) (relu2): ReLU() ) >>> # 执行试运行以初始化网络的惰性模块 >>> lazy_mlp(torch.ones(10,10).cuda()) >>> # 初始化后,LazyLinear 模块变为常规的 Linear 模块 >>> lazy_mlp LazyMLP( (fc1): Linear(in_features=10, out_features=10, bias=True) (relu1): ReLU() (fc2): Linear(in_features=10, out_features=1, bias=True) (relu2): ReLU() ) >>> # 附加一个优化器,因为参数现在可以像往常一样使用 >>> optim = torch.optim.SGD(mlp.parameters(), lr=0.01)
使用惰性模块时需要注意的一点是,网络参数的初始化顺序可能会发生变化,因为惰性模块总是在其他模块之后进行初始化。例如,如果上面定义的 LazyMLP 类首先包含一个
torch.nn.LazyLinear
模块,然后是一个常规的torch.nn.Linear
模块,第二个模块将在构造时初始化,而第一个模块将在第一次试运行时初始化。这可能导致使用惰性模块的网络的参数初始化方式与不使用惰性模块的网络的参数初始化方式不同,因为参数初始化的顺序通常依赖于一个有状态的随机数生成器,而顺序是不同的。更多详情请查看 可重复性。惰性模块可以像其他模块一样使用状态字典进行序列化。例如:
>>> lazy_mlp = LazyMLP() >>> # 状态字典显示未初始化的参数 >>> lazy_mlp.state_dict() OrderedDict([('fc1.weight', Uninitialized parameter), ('fc1.bias', tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), ('fc2.weight', Uninitialized parameter), ('fc2.bias', tensor([0.0019]))])
惰性模块可以加载常规的
torch.nn.Parameter
(即你可以序列化/反序列化已初始化的惰性模块,它们将保持初始化状态)>>> full_mlp = LazyMLP() >>> # 干运行以初始化另一个模块 >>> full_mlp.forward(torch.ones(10, 1)) >>> # 将初始化状态加载到惰性模块中 >>> lazy_mlp.load_state_dict(full_mlp.state_dict()) >>> # 状态字典现在包含有效值 >>> lazy_mlp.state_dict() OrderedDict([('fc1.weight', tensor([[-0.3837], [ 0.0907], [ 0.6708], [-0.5223], [-0.9028], [ 0.2851], [-0.4537], [ 0.6813], [ 0.5766], [-0.8678]])), ('fc1.bias', tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), ('fc2.weight', tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807, 0.2479, 0.1091]])), ('fc2.bias', tensor([0.0019]))])
然而,请注意,如果在加载状态时初始化了参数,则在执行“空运行”时不会替换这些加载的参数。这防止了在不同上下文中使用初始化的模块。