functorch.combine_state_for_ensemble¶
-
functorch.combine_state_for_ensemble(models) → func, params, buffers[source]¶ 准备一个用于与
vmap()集成的torch.nn.Modules列表。给定一个相同类的
M个nn.Modules的列表,将它们的所有参数和缓冲区堆叠在一起,形成params和buffers。结果中的每个参数和缓冲区将具有一个大小为M的额外维度。combine_state_for_ensemble()还返回func,这是models中某个模型的功能版本。不能直接运行func(params, buffers, *args, **kwargs),你可能需要使用vmap(func, ...)(params, buffers, *args, **kwargs)这里有一个如何在一个非常简单的模型上进行集成的示例:
num_models = 5 batch_size = 64 in_features, out_features = 3, 3 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] data = torch.randn(batch_size, 3) fmodel, params, buffers = combine_state_for_ensemble(models) output = vmap(fmodel, (0, 0, None))(params, buffers, data) assert output.shape == (num_models, batch_size, out_features)
警告
所有堆叠在一起的模块必须相同(除了它们的参数/缓冲区的值)。例如,它们应该处于相同的模式(训练模式与评估模式)。
此API可能会发生变化——我们正在研究更好的方法来创建集成,并非常欢迎您提供如何改进此功能的反馈。
警告
我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch.combine_state_for_ensemble在PyTorch 2.0中已被弃用,并将在未来版本PyTorch >= 2.3中删除。请改用torch.func.stack_module_state;更多详情请参阅PyTorch 2.0发布说明和/或torch.func迁移指南https://pytorch.org/docs/master/func.migrating.html