Shortcuts

从 functorch 迁移到 torch.func

torch.func,之前被称为“functorch”,是 类似于JAX 的可组合函数变换,适用于PyTorch。

functorch 最初是在 pytorch/functorch 仓库中的一个外部库。我们的目标一直是将 functorch 直接上游到 PyTorch 中,并将其作为核心 PyTorch 库提供。

作为上游的最后一步,我们决定从顶级包(functorch)迁移到成为PyTorch的一部分,以反映函数转换如何直接集成到PyTorch核心中。从PyTorch 2.0开始,我们弃用了import functorch,并要求用户迁移到最新的API,我们将继续维护这些API。import functorch将保留一段时间以保持向后兼容性。

函数转换

以下API是以下内容的即插即用替代品: functorch API。 它们完全向后兼容。

functorch API

PyTorch API(截至PyTorch 2.0)

functorch.vmap

torch.vmap()torch.func.vmap()

functorch.grad

torch.func.grad()

functorch.vjp

torch.func.vjp()

functorch.jvp

torch.func.jvp()

functorch.jacrev

torch.func.jacrev()

functorch.jacfwd

torch.func.jacfwd()

functorch.hessian

torch.func.hessian()

functorch.functionalize

torch.func.functionalize()

此外,如果您正在使用 torch.autograd.functional API,请尝试使用 torch.func 的等效功能。torch.func 函数变换在许多情况下更具组合性和性能。

torch.autograd.functional API

torch.func API(截至 PyTorch 2.0)

torch.autograd.functional.vjp()

torch.func.grad()torch.func.vjp()

torch.autograd.functional.jvp()

torch.func.jvp()

torch.autograd.functional.jacobian()

torch.func.jacrev()torch.func.jacfwd()

torch.autograd.functional.hessian()

torch.func.hessian()

NN 模块工具

我们已经更改了API,以在NN模块上应用函数变换,使其更好地融入PyTorch设计理念。新的API有所不同,因此请仔细阅读本节内容。

functorch.make_functional

torch.func.functional_call()functorch.make_functionalfunctorch.make_functional_with_buffers 的替代品。然而,它并不是一个直接的替代品。

如果你时间紧迫,可以使用 这个gist中的辅助函数 来模拟functorch.make_functional和functorch.make_functional_with_buffers的行为。 我们建议直接使用torch.func.functional_call(),因为它是一个更明确且灵活的API。

具体来说,functorch.make_functional 返回一个功能模块和参数。 功能模块接受参数和模型的输入作为参数。 torch.func.functional_call() 允许使用新的参数和缓冲区以及输入来调用现有模块的前向传递。

以下是如何使用functorch计算模型参数梯度的示例,与torch.func进行对比:

# ---------------
# 使用 functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# 使用 torch.func (自 PyTorch 2.0 起)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)

以下是如何计算模型参数的雅可比矩阵的示例:

# ---------------
# 使用 functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# 使用 torch.func (自 PyTorch 2.0 起)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev 默认计算 argnums=0 的雅可比矩阵。
# 我们将其设置为 1 以计算参数的雅可比矩阵
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

请注意,为了节省内存,您应该只保留一份参数副本。model.named_parameters() 不会复制参数。如果在模型训练中就地更新模型的参数,那么作为模型的 nn.Module 将保留一份参数副本,一切正常。

然而,如果你想在字典中携带你的参数并在原地更新它们,那么会有两个参数的副本:字典中的一个和模型中的一个。在这种情况下,你应该通过将模型转换为元设备来改变它,使其不持有内存,即model.to('meta')

functorch.combine_state_for_ensemble

请使用 torch.func.stack_module_state() 代替 functorch.combine_state_for_ensemble torch.func.stack_module_state() 返回两个字典,一个是堆叠的参数,另一个是堆叠的缓冲区,然后可以与 torch.vmap()torch.func.functional_call() 用于集成。

例如,以下是如何对一个非常简单的模型进行集成的一个示例:

import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

# ---------------
# 使用 functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# 使用 torch.func (自 PyTorch 2.0 起)
# ------------------------------------
import copy

# 通过将 Tensors 放在 meta 设备上,构造一个没有内存的模型版本。
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# 可以直接对 torch.func.functional_call 进行 vmap,但将其包装在一个函数中可以更清楚地了解正在发生的事情。
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

functorch.compile

我们不再支持 functorch.compile(也称为 AOTAutograd)作为 PyTorch 中编译的前端;我们已经将 AOTAutograd 集成到 PyTorch 的编译流程中。如果您是用户,请改用 torch.compile()

优云智算