Shortcuts

functorch.combine_state_for_ensemble

functorch.combine_state_for_ensemble(models)func, params, buffers[source]

准备一个用于与vmap()集成的torch.nn.Modules列表。

给定一个相同类的Mnn.Modules的列表,将它们的所有参数和缓冲区堆叠在一起,形成paramsbuffers。结果中的每个参数和缓冲区将具有一个大小为M的额外维度。

combine_state_for_ensemble() 还返回 func,这是 models 中某个模型的功能版本。不能直接运行 func(params, buffers, *args, **kwargs),你可能需要使用 vmap(func, ...)(params, buffers, *args, **kwargs)

这里有一个如何在一个非常简单的模型上进行集成的示例:

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

fmodel, params, buffers = combine_state_for_ensemble(models)
output = vmap(fmodel, (0, 0, None))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)

警告

所有堆叠在一起的模块必须相同(除了它们的参数/缓冲区的值)。例如,它们应该处于相同的模式(训练模式与评估模式)。

此API可能会发生变化——我们正在研究更好的方法来创建集成,并非常欢迎您提供如何改进此功能的反馈。

警告

我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch.combine_state_for_ensemble在PyTorch 2.0中已被弃用,并将在未来版本PyTorch >= 2.3中删除。请改用torch.func.stack_module_state;更多详情请参阅PyTorch 2.0发布说明和/或torch.func迁移指南https://pytorch.org/docs/master/func.migrating.html