torch.func¶
torch.func,之前被称为“functorch”,是 类似于JAX 的可组合函数变换,适用于PyTorch。
注意
该库目前处于测试版。 这意味着这些功能通常可以正常工作(除非另有说明),并且我们(PyTorch 团队)致力于推进该库的发展。然而,API 可能会根据用户反馈进行更改,并且我们对 PyTorch 操作的覆盖范围并不全面。
如果您对API或希望涵盖的用例有建议,请提交GitHub问题或联系我们。我们很乐意了解您如何使用该库。
什么是可组合的函数变换?¶
“函数变换”是一种高阶函数,它接受一个数值函数并返回一个新函数,该新函数计算不同的量。
torch.func
具有自动微分变换(grad(f)
返回一个计算f
梯度的函数),一个向量化/批处理变换(vmap(f)
返回一个在输入批次上计算f
的函数),以及其他功能。这些函数变换可以任意地相互组合。例如,组合
vmap(grad(f))
计算一个称为每样本梯度的量,这是 PyTorch 目前无法高效计算的。