Shortcuts

functorch.make_functional_with_buffers

functorch.make_functional_with_buffers(model, disable_autograd_tracking=False)[source]

make_functional(model, disable_autograd_tracking=False) -> func, params

给定一个torch.nn.Modulemake_functional()提取状态(参数)并返回模型的功能版本,func。这使得可以对model的参数进行转换。

func 可以如下调用:

import torch
import torch.nn as nn
from functorch import make_functional

x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)
func(params, x)

这里是一个对模型参数应用梯度变换的示例。

import torch
import torch.nn as nn
from functorch import make_functional, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)

def compute_loss(params, x, t):
    y = func(params, x)
    return nn.functional.mse_loss(y, t)

grad_weights = grad(compute_loss)(params, x, t)

如果模型有任何缓冲区,请使用make_functional_with_buffers()代替。

Parameters
  • model (torch.nn.Module) – 输入模型。

  • disable_autograd_tracking (bool) – 用于禁用输出参数的梯度跟踪的标志。 返回的参数与原始模型的参数集无关。如果为False(默认值), 参数将具有requires_grad=True(即它们将可以通过常规的PyTorch自动梯度进行跟踪), 与原始模型的参数的requires_grad状态匹配。否则,返回的参数将具有requires_grad=False。默认值为False。 如果您计划使用常规的PyTorch自动梯度(例如,如果您想调用.backward()torch.autograd.grad()),则请设置disable_autograd_tracking=False。 否则,如果您只计划使用functorch的梯度变换, 则请设置disable_autograd_tracking=True以避免不必要地使用PyTorch自动梯度跟踪历史记录。

警告

我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch.make_functional_with_buffers在PyTorch 2.0中已被弃用,并将在未来版本PyTorch >= 2.3中删除。请改用torch.func.functional_call;有关更多详细信息,请参阅PyTorch 2.0发布说明和/或torch.func迁移指南https://pytorch.org/docs/master/func.migrating.html