torch.func.stack_module_state
- torch.func.stack_module_state(models) → params, buffers
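Prepares a list of torch.nn.Modules for ensembling with vmap().

Given a list of M nn.Modules of the same class, returns two dictionaries that stack all of their parameters and buffers together, indexed by name. Each stacked tensor gains a new leading dimension of size M.

The stacked parameters are optimizable: they are new leaf nodes in the autograd history, unrelated to the original parameters, and can be passed directly to an optimizer.

Here is an example of how to ensemble over a very simple model: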
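import torch
from torch.func import functional_call, stack_module_state, vmap

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

def wrapper(params, buffers, data):
    # models[0] only supplies the architecture; its state is replaced by params/buffers
    return functional_call(models[0], (params, buffers), data)

params, buffers = stack_module_state(models)
# Map over the leading (model) dimension of params and buffers; share data across models
output = vmap(wrapper, (0, 0, None))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)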
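The claim that the stacked parameters are optimizable can be checked with a short sketch. It assumes an SGD optimizer, a mean-squared-error loss, and random targets, none of which come from the original example; they are only there to drive one update step.

```python
import torch
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)
num_models, batch_size, in_features, out_features = 5, 64, 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
params, buffers = stack_module_state(models)

# Each stacked parameter is a fresh leaf with a leading dimension of num_models.
weight = params["weight"]
is_leaf = weight.is_leaf
requires_grad = weight.requires_grad
stacked_shape = tuple(weight.shape)

# Assumption for illustration: plain SGD over the stacked tensors.
optimizer = torch.optim.SGD(list(params.values()), lr=0.1)

data = torch.randn(batch_size, in_features)
target = torch.randn(batch_size, out_features)  # made-up regression targets

def wrapper(p, b, x):
    # models[0] only supplies the architecture.
    return functional_call(models[0], (p, b), (x,))

before = weight.detach().clone()
output = vmap(wrapper, in_dims=(0, 0, None))(params, buffers, data)
loss = ((output - target) ** 2).mean()
loss.backward()
optimizer.step()

# The update lands on the stacked copy; the original modules receive no gradient.
stacked_changed = not torch.equal(before, weight.detach())
original_grad = models[0].weight.grad
```

Because the stacked tensors are detached copies, training them does not update the original modules. Copying the trained values back into each module, if needed, is a separate step.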
When there are submodules, the keys follow state dict naming conventions:
import torch.nn as nn
from torch.func import stack_module_state

class Foo(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        hidden = 4
        self.l1 = nn.Linear(in_features, hidden)
        self.l2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.l2(self.l1(x))

num_models = 5
in_features, out_features = 3, 3
models = [Foo(in_features, out_features) for i in range(num_models)]
params, buffers = stack_module_state(models)
print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"
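To make the naming and shape conventions concrete, the sketch below inspects both returned dictionaries for a module that also owns buffers. The `Net` class is a hypothetical module introduced only for this illustration; it pairs a linear layer with a `BatchNorm1d` layer so that the `buffers` dictionary is not empty.

```python
import torch
import torch.nn as nn
from torch.func import stack_module_state


class Net(nn.Module):
    """Hypothetical module used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)
        self.bn = nn.BatchNorm1d(4)  # contributes buffers (running statistics)

    def forward(self, x):
        return self.bn(self.fc(x))


num_models = 5
models = [Net() for _ in range(num_models)]
params, buffers = stack_module_state(models)

# Keys match the dotted names of named_parameters() / named_buffers().
param_names = sorted(params.keys())
buffer_names = sorted(buffers.keys())

# Every stacked tensor has a new leading dimension of size num_models.
param_shapes = {name: tuple(t.shape) for name, t in params.items()}
buffer_shapes = {name: tuple(t.shape) for name, t in buffers.items()}

for name, shape in {**param_shapes, **buffer_shapes}.items():
    print(name, shape)
```

Scalar buffers such as `bn.num_batches_tracked` become one-dimensional tensors of length `num_models` after stacking.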
Warning
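All of the modules being stacked together must be the same (except for the values of their parameters/buffers). For example, they should be in the same mode (training vs. eval).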
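One way to guard against the mismatch described in the warning is to compare the modules before stacking them. The helper `check_stackable` below is hypothetical and not part of `torch.func`; it is a minimal sketch that compares class, train/eval mode of every submodule, and state names and shapes.

```python
import torch
from torch.func import stack_module_state


def check_stackable(models):
    """Hypothetical helper (not part of torch.func): verify that the modules
    match in class, train/eval mode, and state names/shapes."""
    ref = models[0]
    ref_modes = [sub.training for sub in ref.modules()]
    ref_spec = {name: tuple(t.shape) for name, t in ref.state_dict().items()}
    for i, model in enumerate(models[1:], start=1):
        if type(model) is not type(ref):
            raise TypeError(f"model {i} has a different class")
        if [sub.training for sub in model.modules()] != ref_modes:
            raise ValueError(f"model {i} is in a different train/eval mode")
        spec = {name: tuple(t.shape) for name, t in model.state_dict().items()}
        if spec != ref_spec:
            raise ValueError(f"model {i} has different state names or shapes")


models = [torch.nn.Linear(3, 3) for _ in range(4)]
models[1].eval()  # one module in eval mode, the rest in training mode

try:
    check_stackable(models)
    mixed_mode_detected = False
except ValueError:
    mixed_mode_detected = True

# Put every module in the same mode, then stack.
for model in models:
    model.eval()
check_stackable(models)
params, buffers = stack_module_state(models)
```

Calling `model.eval()` (or `model.train()`) on every module before stacking is the simplest way to satisfy the same-mode requirement.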