函子¶
functorch 是 PyTorch 的 JAX-like 可组合函数转换。
警告
我们已经将functorch集成到PyTorch中。作为集成的最后一步,functorch API在PyTorch 2.0中已被弃用。请改用torch.func API,并查看迁移指南和文档以获取更多详细信息。
什么是可组合函数转换?¶
“函数变换”是一种高阶函数,它接受一个数值函数并返回一个计算不同量的新函数。
functorch 具有自动微分变换(
grad(f)
返回一个计算f
梯度的函数),向量化/批处理变换(vmap(f)
返回一个在输入批次上计算f
的函数),以及其他功能。这些函数变换可以任意组合。例如,组合
vmap(grad(f))
计算一个称为每样本梯度的量,这是当前PyTorch无法高效计算的。
为什么选择可组合函数转换?¶
在PyTorch中,有许多用例目前难以实现:
计算每个样本的梯度(或其他每个样本的量)
在单台机器上运行模型集合
在MAML的内循环中高效地批量处理任务
高效计算雅可比矩阵和海森矩阵
高效计算批量雅可比矩阵和海森矩阵
组合vmap()
、grad()
和vjp()
转换使我们能够表达上述内容,而无需为每个内容设计单独的子系统。
这种可组合函数转换的想法来自JAX框架。