从 functorch 迁移到 torch.func¶
torch.func,之前被称为“functorch”,是 类似于JAX 的可组合函数变换,适用于PyTorch。
functorch 最初是在 pytorch/functorch 仓库中的一个外部库。我们的目标一直是将 functorch 直接上游到 PyTorch 中,并将其作为核心 PyTorch 库提供。
作为上游的最后一步,我们决定从顶级包(functorch)迁移到成为PyTorch的一部分,以反映函数转换如何直接集成到PyTorch核心中。从PyTorch 2.0开始,我们弃用了import functorch,并要求用户迁移到最新的API,我们将继续维护这些API。import functorch将保留一段时间以保持向后兼容性。
函数转换¶
以下API是以下内容的即插即用替代品: functorch API。 它们完全向后兼容。
functorch API |
PyTorch API(截至PyTorch 2.0) |
|---|---|
functorch.vmap |
|
functorch.grad |
|
functorch.vjp |
|
functorch.jvp |
|
functorch.jacrev |
|
functorch.jacfwd |
|
functorch.hessian |
|
functorch.functionalize |
此外,如果您正在使用 torch.autograd.functional API,请尝试使用 torch.func 的等效功能。torch.func 函数变换在许多情况下更具组合性和性能。
torch.autograd.functional API |
torch.func API(截至 PyTorch 2.0) |
|---|---|
NN 模块工具¶
我们已经更改了API,以在NN模块上应用函数变换,使其更好地融入PyTorch设计理念。新的API有所不同,因此请仔细阅读本节内容。
functorch.make_functional¶
torch.func.functional_call() 是 functorch.make_functional 和 functorch.make_functional_with_buffers 的替代品。然而,它并不是一个直接的替代品。
如果你时间紧迫,可以使用
这个gist中的辅助函数
来模拟functorch.make_functional和functorch.make_functional_with_buffers的行为。
我们建议直接使用torch.func.functional_call(),因为它是一个更明确且灵活的API。
具体来说,functorch.make_functional 返回一个功能模块和参数。
功能模块接受参数和模型的输入作为参数。
torch.func.functional_call() 允许使用新的参数和缓冲区以及输入来调用现有模块的前向传递。
以下是如何使用functorch计算模型参数梯度的示例,与torch.func进行对比:
# ---------------
# 使用 functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
def compute_loss(params, inputs, targets):
prediction = fmodel(params, inputs)
return torch.nn.functional.mse_loss(prediction, targets)
grads = functorch.grad(compute_loss)(params, inputs, targets)
# ------------------------------------
# 使用 torch.func (自 PyTorch 2.0 起)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
def compute_loss(params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)
grads = torch.func.grad(compute_loss)(params, inputs, targets)
以下是如何计算模型参数的雅可比矩阵的示例:
# ---------------
# 使用 functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)
# ------------------------------------
# 使用 torch.func (自 PyTorch 2.0 起)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
# jacrev 默认计算 argnums=0 的雅可比矩阵。
# 我们将其设置为 1 以计算参数的雅可比矩阵
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
请注意,为了节省内存,您应该只保留一份参数副本。model.named_parameters() 不会复制参数。如果在模型训练中就地更新模型的参数,那么作为模型的 nn.Module 将保留一份参数副本,一切正常。
然而,如果你想在字典中携带你的参数并在原地更新它们,那么会有两个参数的副本:字典中的一个和模型中的一个。在这种情况下,你应该通过将模型转换为元设备来改变它,使其不持有内存,即model.to('meta')。
functorch.combine_state_for_ensemble¶
请使用 torch.func.stack_module_state() 代替
functorch.combine_state_for_ensemble
torch.func.stack_module_state() 返回两个字典,一个是堆叠的参数,另一个是堆叠的缓冲区,然后可以与 torch.vmap() 和 torch.func.functional_call()
用于集成。
例如,以下是如何对一个非常简单的模型进行集成的一个示例:
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
# ---------------
# 使用 functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
# ------------------------------------
# 使用 torch.func (自 PyTorch 2.0 起)
# ------------------------------------
import copy
# 通过将 Tensors 放在 meta 设备上,构造一个没有内存的模型版本。
base_model = copy.deepcopy(models[0])
base_model.to('meta')
params, buffers = torch.func.stack_module_state(models)
# 可以直接对 torch.func.functional_call 进行 vmap,但将其包装在一个函数中可以更清楚地了解正在发生的事情。
def call_single_model(params, buffers, data):
return torch.func.functional_call(base_model, (params, buffers), (data,))
output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
functorch.compile¶
我们不再支持 functorch.compile(也称为 AOTAutograd)作为 PyTorch 中编译的前端;我们已经将 AOTAutograd 集成到 PyTorch 的编译流程中。如果您是用户,请改用 torch.compile()。