torch.func.hessian¶
- torch.func.hessian(func, argnums=0)¶
计算
func关于索引为argnum的参数的Hessian,采用前向-反向策略。前向-反向策略(组合
jacfwd(jacrev(func)))是 一个性能良好的默认选择。可以通过其他组合来计算Hessians,例如jacfwd()和jacrev(),如jacfwd(jacfwd(func))或jacrev(jacrev(func))。- Parameters
- Returns
返回一个函数,该函数接受与
func相同的输入,并返回关于argnums的func的海森矩阵。
注意
您可能会看到此API错误显示“未为运算符X实现前向模式AD”。如果是这样,请提交错误报告,我们将优先处理。另一种方法是使用
jacrev(jacrev(func)),它具有更好的运算符覆盖率。一个基本的用法,使用一个 R^N -> R^1 函数,得到一个 N x N 的 Hessian 矩阵:
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # 等价于 jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))