Shortcuts

functorch.hessian

functorch.hessian(func, argnums=0)[source]

计算func关于索引argnum处的参数的海森矩阵,采用正向-反向策略。

前向-反向策略(组合 jacfwd(jacrev(func)))是 一个良好的默认选择,以获得良好的性能。可以通过其他组合方式计算 Hessians, 例如 jacfwd()jacrev() 的组合,如 jacfwd(jacfwd(func))jacrev(jacrev(func))

Parameters
  • func (function) – 一个Python函数,它接受一个或多个参数,其中一个必须是Tensor,并返回一个或多个Tensor

  • argnums (intTuple[int]) – 可选的,整数或整数元组, 指定相对于哪些参数获取Hessian矩阵。 默认值:0。

Returns

返回一个函数,该函数接受与func相同的输入,并返回func关于argnums处的参数的Hessian矩阵。

注意

你可能会看到这个API错误提示“forward-mode AD not implemented for operator X”。如果是这样,请提交一个错误报告,我们会优先处理。另一种方法是使用jacrev(jacrev(func)),它有更好的操作符覆盖范围。

使用一个 R^N -> R^1 函数的基本用法会生成一个 N x N 的 Hessian 矩阵:

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))

警告

我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch.hessian自PyTorch 2.0起已被弃用,并将在未来版本PyTorch >= 2.3中删除。请改用torch.func.hessian;更多详情请参阅PyTorch 2.0发布说明和/或torch.func迁移指南https://pytorch.org/docs/master/func.migrating.html