• Docs >
  • Extending torch.func with autograd.Function
Shortcuts

使用 autograd.Function 扩展 torch.func

所以你希望将 torch.autograd.Functiontorch.func 变换如 torch.vmap()torch.func.grad() 等一起使用。

有两种主要的使用场景:

  • 您希望调用不包含 PyTorch 操作的代码,并使其与函数转换一起工作。也就是说,将 torch.autograd.Function 的前向/后向/等调用转换为其他系统(如 C++、CUDA、numpy)中的函数。

  • 您希望指定自定义梯度规则,例如 JAX 的 custom_vjp/custom_jvp

PyTorch将这两个概念结合到torch.autograd.Function中。

基本用法

本指南假设您熟悉扩展 torch.autograd, 其中解释了如何使用torch.autograd.Function

torch.autograd.Function 可以有一个接受 ctx 对象的 forward(), 或者它可以有一个单独的 forward()(不接受 ctx)和一个修改 ctx 对象的 setup_context() 静态方法。

仅支持使用函数转换的后一种方式:

  • forward() 是执行操作的代码,它不应该接受一个 ctx 对象。

  • setup_context(ctx, inputs, output) 是你可以调用 ctx 方法的代码。在这里,你应该保存用于反向传播的张量(通过调用 ctx.save_for_backward(*tensors)),或者保存非张量(通过将它们分配给 ctx 对象)。

因为 setup_context() 只接受 inputsoutput, 所以唯一可以保存的量是输入或输出中的对象(例如张量)或从它们派生的量(如 Tensor.shape)。 如果你希望保存来自 Function.forward() 的非输入中间激活以用于反向传播,那么你需要将其作为输出从 forward() 返回,以便它被传递给 setup_context()

根据变换的不同,

为了使 torch.autograd.Function 能够与函数变换任意组合,我们建议除了 forward()setup_context() 之外的所有其他静态方法都必须是可变换的:也就是说,它们必须仅由 PyTorch 操作符组成或调用其他 torch.autograd.Function(这些函数可能会调用 C++/CUDA/等)。

让我们来看一些常见用例的示例。

示例 1: autograd.Function 调用另一个系统

一个常见的情况是同时具有torch.autograd.Function的forward()和backward()调用另一个系统(如C++、CUDA、numpy、triton)。

import torch
import numpy as np

def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    # 注意 forward 不接受 ctx
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        # 任何需要在反向传播中保存的中间结果都必须作为输出返回。
        return (
            # 期望的输出
            torch.tensor(result, device=device),
            # 中间结果,用于反向传播
            torch.tensor(ind, device=device),
            # 中间结果,用于反向传播
            torch.tensor(ind_inv, device=device),
        )

    # setup_context 负责调用方法和/或将值分配给 ctx 对象。请不要在 setup_context 中进行额外的计算(例如将张量相加)。
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        # 注意 output 是你在 forward 中返回的任何内容。
        # 如果你返回了多个值,那么 output 是一个包含多个值的元组。
        # 如果你返回了一个单一的张量,那么 output 是一个张量。
        # 如果你返回了一个包含单一张量的元组,那么 output 是一个包含单一张量的元组。
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        # 张量必须通过 ctx.save_for_backward 保存。请不要直接将它们分配到 ctx 对象上。
        ctx.save_for_backward(ind, ind_inv)
        # 非张量可以通过将它们作为属性分配到 ctx 对象上来保存。
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        # 为了使 autograd.Function 能够与函数变换任意组合,除了 forward 和 setup_context 之外的所有静态方法
        # 都必须以“可变换”的方式实现;也就是说,它们必须仅由 PyTorch 操作或 autograd.Function 组成。
        #
        # 例如,这允许我们进行双重反向传播和/或计算二阶梯度。
        #
        # 我们已经用另一个 autograd.Function,NumpyTake,来实现 NumpySort 的反向传播。
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

现在,为了更方便地使用 NumpySort(隐藏我们作为输出返回的中间结果,并允许默认参数和关键字参数),我们创建了一个新函数来调用它:

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

这里是一个合理性检查:

x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))

示例 2:autograd.Function 指定自定义梯度规则

另一个常见的情况是使用 PyTorch 操作实现的 torch.autograd.Function。PyTorch 能够自动计算 PyTorch 操作的梯度,但也许我们希望自定义梯度的计算方式。我们可能希望自定义反向传播而不是使用 PyTorch 提供的反向传播的原因有:

  • 提高数值稳定性

  • 改变反向传播的性能特征

  • 改变如何处理边缘情况(例如,nans,inf)

  • 修改梯度(例如梯度裁剪)

这是一个关于函数 y = x ** 3torch.autograd.Function 示例,其中我们改变了性能特征(通常在反向传播过程中发生的某些计算,计算 dx,在正向传播过程中发生)。

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        result = x ** 3
        # 在常规的 PyTorch 中,如果我们只是运行 y = x ** 3,那么在反向传播中
        # 会计算 dx = 3 * x ** 2。在这个 autograd.Function 中,我们在这里的前向传播中
        # 进行了这个计算。
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # 为了使 autograd.Function 能够与高阶梯度一起工作,我们必须添加 `dx` 的梯度贡献。
        result = grad_output * dx + grad_dx * 6 * x
        return result

现在,为了更方便地使用 NumpySort(并隐藏我们作为输出返回的中间结果),我们创建了一个新函数来调用它:

def my_cube(x):
    result, _ = MyCube.apply(x)
    return result

这里是一个计算二阶梯度的合理性检查:

x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)

限制和注意事项

警告

请仔细阅读这些关于torch.autograd.Function与torch.func变换的限制。我们无法捕获许多这些情况并优雅地报错,因此它们将导致未定义行为。

请不要将正在转换、具有 requires_grad=True 或为双张量的张量捕获到 torch.autograd.Function 的方法中。确保完全安全的方法是确保在 torch.autograd.Function 的任何方法中使用的唯一张量必须直接作为输入传递(或通过 ctx 对象),而不是来自 torch.autograd.Function 外部。

torch.autograd.Function 不处理pytrees中的张量(可能包含或不包含张量的任意嵌套Python数据结构)。为了使这些张量被autograd跟踪,它们必须直接作为参数传递给torch.autograd.Function。这与jax.{custom_vjp, custom_jvp}不同,后者接受pytrees。

请仅使用 save_for_backward()save_for_forward() 来保存张量。 请不要直接将张量或张量集合分配到 ctx 对象上 - 这些张量将不会被跟踪

torch.vmap() 支持

要使用带有torch.autograd.Functiontorch.vmap(),您必须:

自动生成vmap规则

如果你的 torch.autograd.Function 满足以下附加约束,那么我们能够为其生成一个 vmap 规则。如果不满足这些约束或你希望在 vmap 下自定义行为,请手动定义一个 vmap 静态方法(见下一节)。

警告

我们不容易检查以下约束并优雅地报错。违反这些约束可能导致未定义行为。

示例:

class MyCube(torch.autograd.Function):
    # 将 generate_vmap_rule 设置为 True,以要求 PyTorch 自动生成
    # 一个 vmap 规则。
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        result = grad_output * dx + grad_dx * 6 * x
        return result

def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)

定义 vmap 静态方法

如果你的 torch.autograd.Function 调用了另一个系统(如 NumPy、C++、CUDA、triton), 那么为了使其与 torch.vmap() 或使用它的变换一起工作,你需要手动定义一个 vmap() 静态方法。

根据您想要使用的变换和您的使用场景,您可能不需要将 vmap() 静态方法添加到所有的 torch.autograd.Function 中:

我们确实建议确保您的所有 torch.autograd.Function 都支持 torch.vmap(),特别是如果您正在编写第三方库并且希望您的 torch.autograd.Function 与所有 torch.func() 变换的组合一起工作。

从概念上讲,vmap静态方法是负责定义forward()torch.vmap()下的行为。也就是说,它定义了如何将forward()转换为在具有额外维度的输入上运行(该维度是vmapped的维度)。这与torch.vmap()在PyTorch操作上的实现类似:对于每个操作,我们定义一个vmap规则(有时也称为“批处理规则”)。

以下是如何定义 vmap() 静态方法:

  • 签名是 vmap(info, in_dims: Tuple[Optional[int]], *args),其中 *args 与传递给 forward() 的参数相同。

  • vmap 静态方法负责定义在 forward()torch.vmap() 下的行为。也就是说,给定具有额外维度的输入(由 in_dims 指定),我们如何计算 forward() 的批量版本?

  • 对于args中的每个参数,in_dims都有一个对应的Optional[int]。 如果该参数不是张量或该参数未被vmapped处理,则为None, 否则,它是一个整数,指定张量的哪个维度正在被vmapped处理。

  • info 是一个包含额外元数据的集合,这些元数据可能会有所帮助: info.batch_size 指定了正在 vmapped 的维度的尺寸,而 info.randomness 是传递给 torch.vmap()randomness 选项。

  • vmap静态方法的返回值是一个包含(output, out_dims)的元组。与in_dims类似,out_dims应该与output具有相同的结构,并且包含每个输出对应的out_dim,用于指定输出是否具有vmapped维度以及该维度的索引位置。

示例:

```html
def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    # vmap静态方法的签名是:
    # vmap(info, in_dims: Tuple[Optional[int]], *args)
    # 其中*args与`forward`的参数相同。
    @staticmethod
    def vmap(info, in_dims, x, dim):
        # 对于每个输入(x和dim),in_dims存储一个Optional[int]
        # 即:
        # - 如果输入没有被vmapped或者输入不是Tensor,则为None
        # - 如果输入被vmapped,则为一个整数,表示被vmapped的维度的索引。
        x_bdim, _ = in_dims

        # "vmap规则"是关于如何在对输入增加一个维度的情况下执行操作的逻辑。在NumpySort中,x有一个额外的维度(x_bdim)。vmap规则很简单,就是再次调用NumpySort,但传递一个不同的`dim`。
        x = x.movedim(x_bdim, 0)
        # 正确处理负的dim
        dim = dim if dim >= 0 else dim + x.dim() - 1
        result = NumpySort.apply(x, dim + 1)

        # vmap规则必须返回两个元组
        # 1. 输出。应该与forward()返回的内容数量相同。
        # 2. 每个输出一个Optional[int],指定每个输出是否被vmapped,如果是,则指定被vmapped的维度的索引。
        #
        # NumpySort.forward返回一个包含3个Tensor的元组。由于我们将被vmapped的维度移动到`x`的前面,它出现在所有输出的维度0。
        # 返回值是(output, out_dims) -- output是一个包含3个Tensor的元组,out_dims是一个包含3个Optional[int]的元组
        return NumpySort.apply(x, dim + 1), (0, 0, 0)

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def

注意

vmap 静态方法应旨在保留整个 Function 的语义。也就是说,(伪代码)grad(vmap(MyFunc)) 应该可以替换为 grad(map(MyFunc))

如果你的autograd.Function在反向传播过程中有任何自定义行为,请记住这一点。

注意

为 PyTorch 能够通过 generate_vmap_rule=True 生成 vmap 规则的 Function 编写自定义的 vmap 静态方法是合法的使用场景。如果您希望生成的 vmap 规则不符合您所需的语义,您可能希望这样做。

torch.func.jvp() 支持

为了支持前向模式自动微分,一个torch.autograd.Function必须有一个jvp()静态方法。 详情请参见前向模式自动微分