使用 autograd.Function 扩展 torch.func¶
所以你希望将 torch.autograd.Function
与 torch.func
变换如 torch.vmap()
、torch.func.grad()
等一起使用。
有两种主要的使用场景:
您希望调用不包含 PyTorch 操作的代码,并使其与函数转换一起工作。也就是说,将
torch.autograd.Function
的前向/后向/等调用转换为其他系统(如 C++、CUDA、numpy)中的函数。您希望指定自定义梯度规则,例如 JAX 的 custom_vjp/custom_jvp
PyTorch将这两个概念结合到torch.autograd.Function
中。
基本用法¶
本指南假设您熟悉扩展 torch.autograd,
其中解释了如何使用torch.autograd.Function
。
torch.autograd.Function
可以有一个接受 ctx 对象的 forward()
,
或者它可以有一个单独的 forward()
(不接受 ctx
)和一个修改 ctx
对象的 setup_context()
静态方法。
仅支持使用函数转换的后一种方式:
forward()
是执行操作的代码,它不应该接受一个ctx
对象。setup_context(ctx, inputs, output)
是你可以调用ctx
方法的代码。在这里,你应该保存用于反向传播的张量(通过调用ctx.save_for_backward(*tensors)
),或者保存非张量(通过将它们分配给ctx
对象)。
因为 setup_context()
只接受 inputs
和 output
,
所以唯一可以保存的量是输入或输出中的对象(例如张量)或从它们派生的量(如 Tensor.shape
)。
如果你希望保存来自 Function.forward()
的非输入中间激活以用于反向传播,那么你需要将其作为输出从 forward()
返回,以便它被传递给
setup_context()
。
根据变换的不同,
为了支持反向模式自动微分(
torch.func.grad()
,torch.func.vjp()
),torch.autograd.Function
需要一个backward()
静态方法。为了支持
torch.vmap()
,torch.autograd.Function
需要一个vmap()
静态方法。为了支持
torch.func.jvp()
,torch.autograd.Function
需要一个jvp()
静态方法。支持变换的组合(如
torch.func.jacrev()
,torch.func.jacfwd()
,torch.func.hessian()
)——您可能需要多个上述内容。
为了使 torch.autograd.Function
能够与函数变换任意组合,我们建议除了 forward()
和
setup_context()
之外的所有其他静态方法都必须是可变换的:也就是说,它们必须仅由 PyTorch 操作符组成或调用其他 torch.autograd.Function
(这些函数可能会调用 C++/CUDA/等)。
让我们来看一些常见用例的示例。
示例 1: autograd.Function 调用另一个系统¶
一个常见的情况是同时具有torch.autograd.Function
的forward()和backward()调用另一个系统(如C++、CUDA、numpy、triton)。
import torch
import numpy as np
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
# 注意 forward 不接受 ctx
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
# 任何需要在反向传播中保存的中间结果都必须作为输出返回。
return (
# 期望的输出
torch.tensor(result, device=device),
# 中间结果,用于反向传播
torch.tensor(ind, device=device),
# 中间结果,用于反向传播
torch.tensor(ind_inv, device=device),
)
# setup_context 负责调用方法和/或将值分配给 ctx 对象。请不要在 setup_context 中进行额外的计算(例如将张量相加)。
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
# 注意 output 是你在 forward 中返回的任何内容。
# 如果你返回了多个值,那么 output 是一个包含多个值的元组。
# 如果你返回了一个单一的张量,那么 output 是一个张量。
# 如果你返回了一个包含单一张量的元组,那么 output 是一个包含单一张量的元组。
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
# 张量必须通过 ctx.save_for_backward 保存。请不要直接将它们分配到 ctx 对象上。
ctx.save_for_backward(ind, ind_inv)
# 非张量可以通过将它们作为属性分配到 ctx 对象上来保存。
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
# 为了使 autograd.Function 能够与函数变换任意组合,除了 forward 和 setup_context 之外的所有静态方法
# 都必须以“可变换”的方式实现;也就是说,它们必须仅由 PyTorch 操作或 autograd.Function 组成。
#
# 例如,这允许我们进行双重反向传播和/或计算二阶梯度。
#
# 我们已经用另一个 autograd.Function,NumpyTake,来实现 NumpySort 的反向传播。
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
现在,为了更方便地使用 NumpySort
(隐藏我们作为输出返回的中间结果,并允许默认参数和关键字参数),我们创建了一个新函数来调用它:
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
这里是一个合理性检查:
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd.Function 指定自定义梯度规则¶
另一个常见的情况是使用 PyTorch 操作实现的 torch.autograd.Function
。PyTorch 能够自动计算 PyTorch 操作的梯度,但也许我们希望自定义梯度的计算方式。我们可能希望自定义反向传播而不是使用 PyTorch 提供的反向传播的原因有:
提高数值稳定性
改变反向传播的性能特征
改变如何处理边缘情况(例如,nans,inf)
修改梯度(例如梯度裁剪)
这是一个关于函数 y = x ** 3
的 torch.autograd.Function
示例,其中我们改变了性能特征(通常在反向传播过程中发生的某些计算,计算 dx,在正向传播过程中发生)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# 在常规的 PyTorch 中,如果我们只是运行 y = x ** 3,那么在反向传播中
# 会计算 dx = 3 * x ** 2。在这个 autograd.Function 中,我们在这里的前向传播中
# 进行了这个计算。
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# 为了使 autograd.Function 能够与高阶梯度一起工作,我们必须添加 `dx` 的梯度贡献。
result = grad_output * dx + grad_dx * 6 * x
return result
现在,为了更方便地使用 NumpySort
(并隐藏我们作为输出返回的中间结果),我们创建了一个新函数来调用它:
def my_cube(x):
result, _ = MyCube.apply(x)
return result
这里是一个计算二阶梯度的合理性检查:
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和注意事项¶
警告
请仔细阅读这些关于torch.autograd.Function
与torch.func变换的限制。我们无法捕获许多这些情况并优雅地报错,因此它们将导致未定义行为。
请不要将正在转换、具有 requires_grad=True 或为双张量的张量捕获到 torch.autograd.Function
的方法中。确保完全安全的方法是确保在 torch.autograd.Function
的任何方法中使用的唯一张量必须直接作为输入传递(或通过 ctx 对象),而不是来自 torch.autograd.Function
外部。
torch.autograd.Function
不处理pytrees中的张量(可能包含或不包含张量的任意嵌套Python数据结构)。为了使这些张量被autograd跟踪,它们必须直接作为参数传递给torch.autograd.Function
。这与jax.{custom_vjp, custom_jvp}不同,后者接受pytrees。
请仅使用 save_for_backward()
或
save_for_forward()
来保存张量。
请不要直接将张量或张量集合分配到 ctx 对象上 -
这些张量将不会被跟踪
torch.vmap()
支持¶
要使用带有torch.autograd.Function
的torch.vmap()
,您必须:
提供一个
vmap()
静态方法,告诉我们torch.autograd.Function
在torch.vmap()
下的行为通过设置
generate_vmap_rule=True
,要求我们自动生成它。
自动生成vmap规则¶
如果你的 torch.autograd.Function
满足以下附加约束,那么我们能够为其生成一个 vmap 规则。如果不满足这些约束或你希望在 vmap 下自定义行为,请手动定义一个 vmap 静态方法(见下一节)。
警告
我们不容易检查以下约束并优雅地报错。违反这些约束可能导致未定义行为。
The
torch.autograd.Function
’sforward()
,backward()
(如果存在) 和jvp()
(如果存在) 静态方法必须可以通过torch.vmap()
进行转换。也就是说,它们必须仅由 PyTorch 操作组成(而不是例如 NumPy 或自定义 CUDA 内核)。
示例:
class MyCube(torch.autograd.Function):
# 将 generate_vmap_rule 设置为 True,以要求 PyTorch 自动生成
# 一个 vmap 规则。
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定义 vmap 静态方法¶
如果你的 torch.autograd.Function
调用了另一个系统(如 NumPy、C++、CUDA、triton),
那么为了使其与 torch.vmap()
或使用它的变换一起工作,你需要手动定义一个 vmap()
静态方法。
根据您想要使用的变换和您的使用场景,您可能不需要将 vmap()
静态方法添加到所有的 torch.autograd.Function
中:
例如,
torch.func.jacrev()
在反向传播过程中执行vmap()
。 因此,如果你只对使用torch.func.jacrev()
感兴趣,只需要backward()
静态方法可映射。
我们确实建议确保您的所有 torch.autograd.Function
都支持
torch.vmap()
,特别是如果您正在编写第三方库并且希望您的
torch.autograd.Function
与所有 torch.func()
变换的组合一起工作。
从概念上讲,vmap静态方法是负责定义forward()
在torch.vmap()
下的行为。也就是说,它定义了如何将forward()
转换为在具有额外维度的输入上运行(该维度是vmapped的维度)。这与torch.vmap()
在PyTorch操作上的实现类似:对于每个操作,我们定义一个vmap规则(有时也称为“批处理规则”)。
以下是如何定义 vmap()
静态方法:
签名是
vmap(info, in_dims: Tuple[Optional[int]], *args)
,其中*args
与传递给forward()
的参数相同。vmap 静态方法负责定义在
forward()
在torch.vmap()
下的行为。也就是说,给定具有额外维度的输入(由in_dims
指定),我们如何计算forward()
的批量版本?对于
args
中的每个参数,in_dims
都有一个对应的Optional[int]
。 如果该参数不是张量或该参数未被vmapped处理,则为None
, 否则,它是一个整数,指定张量的哪个维度正在被vmapped处理。info
是一个包含额外元数据的集合,这些元数据可能会有所帮助:info.batch_size
指定了正在 vmapped 的维度的尺寸,而info.randomness
是传递给torch.vmap()
的randomness
选项。vmap静态方法的返回值是一个包含
(output, out_dims)
的元组。与in_dims
类似,out_dims
应该与output
具有相同的结构,并且包含每个输出对应的out_dim
,用于指定输出是否具有vmapped维度以及该维度的索引位置。
示例:
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
# vmap静态方法的签名是:
# vmap(info, in_dims: Tuple[Optional[int]], *args)
# 其中*args与`forward`的参数相同。
@staticmethod
def vmap(info, in_dims, x, dim):
# 对于每个输入(x和dim),in_dims存储一个Optional[int]
# 即:
# - 如果输入没有被vmapped或者输入不是Tensor,则为None
# - 如果输入被vmapped,则为一个整数,表示被vmapped的维度的索引。
x_bdim, _ = in_dims
# "vmap规则"是关于如何在对输入增加一个维度的情况下执行操作的逻辑。在NumpySort中,x有一个额外的维度(x_bdim)。vmap规则很简单,就是再次调用NumpySort,但传递一个不同的`dim`。
x = x.movedim(x_bdim, 0)
# 正确处理负的dim
dim = dim if dim >= 0 else dim + x.dim() - 1
result = NumpySort.apply(x, dim + 1)
# vmap规则必须返回两个元组
# 1. 输出。应该与forward()返回的内容数量相同。
# 2. 每个输出一个Optional[int],指定每个输出是否被vmapped,如果是,则指定被vmapped的维度的索引。
#
# NumpySort.forward返回一个包含3个Tensor的元组。由于我们将被vmapped的维度移动到`x`的前面,它出现在所有输出的维度0。
# 返回值是(output, out_dims) -- output是一个包含3个Tensor的元组,out_dims是一个包含3个Optional[int]的元组
return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def
注意
vmap 静态方法应旨在保留整个 Function
的语义。也就是说,(伪代码)grad(vmap(MyFunc))
应该可以替换为 grad(map(MyFunc))
。
如果你的autograd.Function在反向传播过程中有任何自定义行为,请记住这一点。
注意
为 PyTorch 能够通过 generate_vmap_rule=True
生成 vmap 规则的 Function
编写自定义的 vmap 静态方法是合法的使用场景。如果您希望生成的 vmap 规则不符合您所需的语义,您可能希望这样做。
torch.func.jvp()
支持¶
为了支持前向模式自动微分,一个torch.autograd.Function
必须有一个jvp()
静态方法。
详情请参见前向模式自动微分。