Shortcuts

functorch.grad_and_value

functorch.grad_and_value(func, argnums=0, has_aux=False)[source]

返回一个函数,用于计算梯度和原始或前向计算的元组。

Parameters
  • func (可调用) – 一个接受一个或多个参数的Python函数。 必须返回一个单元素张量。如果指定了has_aux 等于True,函数可以返回一个单元素张量和其他辅助对象的元组:(output, aux)

  • argnums (intTuple[int]) – 指定要计算梯度的参数。argnums 可以是单个整数或整数元组。默认值:0。

  • has_aux (bool) – 标志,表示 func 返回一个张量和其他辅助对象:(output, aux)。默认值:False。

Returns

计算相对于其输入和前向计算的梯度元组的函数。默认情况下,函数的输出是相对于第一个参数的梯度张量(或多个张量)和前向计算的元组。如果指定了has_aux等于True,则返回梯度的元组和带有输出辅助对象的前向计算的元组。如果argnums是一个整数元组,则返回相对于每个argnums值的输出梯度的元组和前向计算的元组。

参见 grad() 的示例

警告

我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch.grad_and_value自PyTorch 2.0起已被弃用,并将在未来版本PyTorch >= 2.3中删除。请改用torch.func.grad_and_value;有关更多详细信息,请参阅PyTorch 2.0发布说明和/或torch.func迁移指南https://pytorch.org/docs/master/func.migrating.html