functorch.make_functional_with_buffers¶
-
functorch.make_functional_with_buffers(model, disable_autograd_tracking=False)[source]¶ make_functional(model, disable_autograd_tracking=False) -> func, params
给定一个
torch.nn.Module,make_functional()提取状态(参数)并返回模型的功能版本,func。这使得可以对model的参数进行转换。func可以如下调用:import torch import torch.nn as nn from functorch import make_functional x = torch.randn(4, 3) model = nn.Linear(3, 3) func, params = make_functional(model) func(params, x)
这里是一个对模型参数应用梯度变换的示例。
import torch import torch.nn as nn from functorch import make_functional, grad x = torch.randn(4, 3) t = torch.randn(4, 3) model = nn.Linear(3, 3) func, params = make_functional(model) def compute_loss(params, x, t): y = func(params, x) return nn.functional.mse_loss(y, t) grad_weights = grad(compute_loss)(params, x, t)
如果模型有任何缓冲区,请使用
make_functional_with_buffers()代替。- Parameters
model (torch.nn.Module) – 输入模型。
disable_autograd_tracking (bool) – 用于禁用输出参数的梯度跟踪的标志。 返回的参数与原始模型的参数集无关。如果为False(默认值), 参数将具有
requires_grad=True(即它们将可以通过常规的PyTorch自动梯度进行跟踪), 与原始模型的参数的requires_grad状态匹配。否则,返回的参数将具有requires_grad=False。默认值为False。 如果您计划使用常规的PyTorch自动梯度(例如,如果您想调用.backward()或torch.autograd.grad()),则请设置disable_autograd_tracking=False。 否则,如果您只计划使用functorch的梯度变换, 则请设置disable_autograd_tracking=True以避免不必要地使用PyTorch自动梯度跟踪历史记录。
警告
我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch.make_functional_with_buffers在PyTorch 2.0中已被弃用,并将在未来版本PyTorch >= 2.3中删除。请改用torch.func.functional_call;有关更多详细信息,请参阅PyTorch 2.0发布说明和/或torch.func迁移指南https://pytorch.org/docs/master/func.migrating.html