Shortcuts

torch.func.hessian

torch.func.hessian(func, argnums=0)

计算func关于索引为argnum的参数的Hessian,采用前向-反向策略。

前向-反向策略(组合 jacfwd(jacrev(func)))是 一个性能良好的默认选择。可以通过其他组合来计算Hessians,例如 jacfwd()jacrev(),如 jacfwd(jacfwd(func))jacrev(jacrev(func))

Parameters
  • func (函数) – 一个Python函数,接受一个或多个参数,其中必须有一个是Tensor,并返回一个或多个Tensor

  • argnums (intTuple[int]) – 可选,整数或整数元组,指示要对哪些参数计算Hessian。默认值:0。

Returns

返回一个函数,该函数接受与func相同的输入,并返回关于argnumsfunc的海森矩阵。

注意

您可能会看到此API错误显示“未为运算符X实现前向模式AD”。如果是这样,请提交错误报告,我们将优先处理。另一种方法是使用jacrev(jacrev(func)),它具有更好的运算符覆盖率。

一个基本的用法,使用一个 R^N -> R^1 函数,得到一个 N x N 的 Hessian 矩阵:

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # 等价于 jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))
优云智算