Shortcuts

函子

functorch 是 PyTorch 的 JAX-like 可组合函数转换。

警告

我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch API在PyTorch 2.0中已被弃用。请改用torch.func API,并查看迁移指南文档以获取更多详细信息。

什么是可组合函数转换?

  • “函数变换”是一种高阶函数,它接受一个数值函数并返回一个计算不同量的新函数。

  • functorch 具有自动微分变换(grad(f) 返回一个计算 f 梯度的函数),向量化/批处理变换(vmap(f) 返回一个在输入批次上计算 f 的函数),以及其他功能。

  • 这些函数变换可以任意组合。例如,组合vmap(grad(f))计算一个称为每样本梯度的量,这是当前PyTorch无法高效计算的。

为什么选择可组合函数转换?

在PyTorch中,有许多用例目前难以实现:

  • 计算每个样本的梯度(或其他每个样本的量)

  • 在单台机器上运行模型集合

  • 在MAML的内循环中高效地批量处理任务

  • 高效计算雅可比矩阵和海森矩阵

  • 高效计算批量雅可比矩阵和海森矩阵

组合vmap()grad()vjp()转换使我们能够表达上述内容,而无需为每个内容设计单独的子系统。 这种可组合函数转换的想法来自JAX框架