Shortcuts

torch.func

torch.func,之前被称为“functorch”,是 类似于JAX 的可组合函数变换,适用于PyTorch。

注意

该库目前处于测试版。 这意味着这些功能通常可以正常工作(除非另有说明),并且我们(PyTorch 团队)致力于推进该库的发展。然而,API 可能会根据用户反馈进行更改,并且我们对 PyTorch 操作的覆盖范围并不全面。

如果您对API或希望涵盖的用例有建议,请提交GitHub问题或联系我们。我们很乐意了解您如何使用该库。

什么是可组合的函数变换?

  • “函数变换”是一种高阶函数,它接受一个数值函数并返回一个新函数,该新函数计算不同的量。

  • torch.func 具有自动微分变换(grad(f) 返回一个计算 f 梯度的函数),一个向量化/批处理变换(vmap(f) 返回一个在输入批次上计算 f 的函数),以及其他功能。

  • 这些函数变换可以任意地相互组合。例如,组合 vmap(grad(f)) 计算一个称为每样本梯度的量,这是 PyTorch 目前无法高效计算的。

为什么使用可组合的函数变换?

在PyTorch中,有一些用例目前做起来比较棘手:

  • 计算每个样本的梯度(或其他每个样本的量)

  • 在单台机器上运行模型集成

  • 在内循环中高效地将任务批处理在一起

  • 高效计算雅可比矩阵和海森矩阵

  • 高效计算批量雅可比矩阵和海森矩阵

组合 vmap()grad()vjp() 变换使我们能够表达上述内容,而无需为每个内容设计一个单独的子系统。 这种可组合函数变换的思想来自 JAX 框架