functorch.hessian¶
-
functorch.hessian(func, argnums=0)[source]¶ 计算
func关于索引argnum处的参数的海森矩阵,采用正向-反向策略。前向-反向策略(组合
jacfwd(jacrev(func)))是 一个良好的默认选择,以获得良好的性能。可以通过其他组合方式计算 Hessians, 例如jacfwd()和jacrev()的组合,如jacfwd(jacfwd(func))或jacrev(jacrev(func))。- Parameters
- Returns
返回一个函数,该函数接受与
func相同的输入,并返回func关于argnums处的参数的Hessian矩阵。
注意
你可能会看到这个API错误提示“forward-mode AD not implemented for operator X”。如果是这样,请提交一个错误报告,我们会优先处理。另一种方法是使用
jacrev(jacrev(func)),它有更好的操作符覆盖范围。使用一个 R^N -> R^1 函数的基本用法会生成一个 N x N 的 Hessian 矩阵:
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))
警告
我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch.hessian自PyTorch 2.0起已被弃用,并将在未来版本PyTorch >= 2.3中删除。请改用torch.func.hessian;更多详情请参阅PyTorch 2.0发布说明和/或torch.func迁移指南https://pytorch.org/docs/master/func.migrating.html