AOT Autograd - 如何使用和优化?¶
背景¶
在本教程中,我们将学习如何使用AOT Autograd来加速深度学习模型的训练。
作为背景,AOT Autograd 是一个帮助开发者在 PyTorch 上加速训练的工具包。总的来说,它有两个关键特性
AOT Autograd 提前追踪前向和后向图。提前存在前向和后向图有助于联合图优化,例如重新计算或激活检查点。
AOT Autograd 提供了简单的机制,通过深度学习编译器(如 NVFuser、NNC、TVM 等)来编译提取的前向和后向图。
你将学到什么?¶
在本教程中,我们将探讨如何使用AOT Autograd与后端编译器结合,以加速PyTorch模型的训练。更具体地说,您将学习
如何使用AOT Autograd?
AOT Autograd 如何使用后端编译器执行操作融合?
AOT Autograd 如何实现训练特定的优化,例如重新计算?
那么,让我们开始吧。
设置¶
让我们设置一个简单的模型。
import torch
def fn(a, b, c, d):
x = a + b + c + d
return x.cos().cos()
# Test that it works
a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]
ref = fn(a, b, c, d)
loss = ref.sum()
loss.backward()
使用AOT Autograd¶
现在,让我们使用AOT Autograd并查看提取的前向和后向图。在内部,AOT使用基于__torch_dispatch__的跟踪机制来提取前向和后向图,并将它们包装在torch.Fx GraphModule容器中。请注意,AOT Autograd跟踪与通常的Fx符号跟踪不同。AOT Autograd使用Fx GraphModule仅用于表示跟踪的图(而不是用于跟踪)。
AOT Autograd 然后将这些前向和后向图发送给用户提供的编译器。所以,让我们编写一个只打印图的编译器。
from functorch.compile import aot_function
# The compiler_fn is called after the forward and backward graphs are extracted.
# Here, we just print the code in the compiler_fn. Return of this function is a callable.
def compiler_fn(fx_module: torch.fx.GraphModule, _):
print(fx_module.code)
return fx_module
# Pass on the compiler_fn to the aot_function API
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
# Run the aot_print_fn once to trigger the compilation and print the graphs
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)
res.sum().backward()
assert torch.allclose(ref, res)
def forward(self, primals_1, primals_2, primals_3, primals_4):
add = torch.ops.aten.add(primals_1, primals_2); primals_1 = primals_2 = None
add_1 = torch.ops.aten.add(add, primals_3); add = primals_3 = None
add_2 = torch.ops.aten.add(add_1, primals_4); add_1 = primals_4 = None
cos = torch.ops.aten.cos(add_2)
cos_1 = torch.ops.aten.cos(cos)
return [cos_1, add_2, cos]
def forward(self, add_2, cos, tangents_1):
sin = torch.ops.aten.sin(cos); cos = None
neg = torch.ops.aten.neg(sin); sin = None
mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None
sin_1 = torch.ops.aten.sin(add_2); add_2 = None
neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None
mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None
return [mul_1, mul_1, mul_1, mul_1]
上述代码打印了前向和后向图的Fx图。你可以看到,除了前向传递的原始输入外,前向图还输出了一些额外的张量。这些张量被保存下来用于后向传递的梯度计算。我们稍后在讨论重计算时会回到这些内容。
操作符融合¶
既然我们已经了解了如何使用AOT Autograd来打印前向和后向图,现在让我们使用AOT Autograd来使用一些实际的深度学习编译器。在本教程中,我们使用PyTorch神经网络编译器(NNC)来为CPU设备执行点操作符融合。对于CUDA设备,一个合适的替代方案是NvFuser。所以,让我们使用NNC
# AOT Autograd has a suite of already integrated backends. Lets import the NNC compiler backend - ts_compile
from functorch.compile import ts_compile
# Lets compile the forward and backward through ts_compile.
aot_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)
# Correctness checking. Lets clone the input so that we can check grads.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_nnc_fn(*cloned_inputs)
loss = res.sum()
loss.backward()
assert torch.allclose(ref, res)
assert torch.allclose(a.grad, cloned_a.grad)
assert torch.allclose(b.grad, cloned_b.grad)
assert torch.allclose(c.grad, cloned_c.grad)
assert torch.allclose(d.grad, cloned_d.grad)
让我们对原始函数和AOT Autograd + NNC编译后的函数进行基准测试。
# Lets write a function to benchmark the forward and backward pass
import time
import statistics
def bench(fn, args, prefix):
warmup = 10
iterations = 100
for _ in range(warmup):
ref = fn(*args)
ref.sum().backward()
fw_latencies = []
bw_latencies = []
for _ in range(iterations):
for arg in args:
arg.grad = None
fw_begin = time.perf_counter()
ref = fn(*args)
fw_end = time.perf_counter()
loss = ref.sum()
bw_begin = time.perf_counter()
loss.backward()
bw_end = time.perf_counter()
fw_latencies.append(fw_end - fw_begin)
bw_latencies.append(bw_end - bw_begin)
avg_fw_latency = statistics.mean(fw_latencies) * 10**6
avg_bw_latency = statistics.mean(bw_latencies) * 10**6
print(prefix, "Fwd = " + str(avg_fw_latency) + " us", "Bwd = " + str(avg_bw_latency) + " us", sep=', ')
large_inputs = [torch.randn(1024, 2048, requires_grad=True) for _ in range(4)]
# Benchmark the Eager and AOT Autograd functions
bench(fn, large_inputs, "Eager")
bench(aot_nnc_fn, large_inputs, "AOT")
Eager, Fwd = 982.6959593920038 us, Bwd = 1899.7003795811906 us
AOT, Fwd = 734.2723174951971 us, Bwd = 831.1696897726506 us
在NNC的帮助下,AOT Autograd加速了前向和后向传递。如果我们查看之前打印的图表,所有的操作符都是逐点的。逐点操作符受内存带宽限制,因此从操作符融合中受益。仔细观察数字,后向传递获得了更高的加速。这是因为前向传递必须输出一些中间张量用于后向传递的梯度计算,这阻止了它节省一些内存读取和写入。然而,这种限制在后向图中不存在。
重新计算(又称激活检查点)¶
重新计算(通常称为激活检查点)是一种技术,在这种技术中,我们不是在反向传播时保存一些激活值以供使用,而是在反向传播过程中重新计算它们。重新计算可以节省内存,但我们会因此产生性能开销。
然而,在存在融合编译器的情况下,我们可以做得更好。我们可以重新计算适合融合的运算符以节省内存,然后依靠融合编译器来融合重新计算的运算符。这既减少了内存使用,也减少了运行时间。更多详情请参阅此讨论帖子。
在这里,我们使用AOT Autograd与NNC来执行类似类型的重新计算。在__torch_dispatch__跟踪结束时,AOT Autograd有一个前向图和联合前向-后向图。然后,AOT Autograd使用一个分区器来隔离前向和后向图。在上面的例子中,我们使用了一个默认的分区器。对于这个实验,我们将使用另一个名为min_cut_rematerialization_partition的分区器来执行更智能的融合感知重新计算。分区器是可配置的,人们可以编写自己的分区器并将其插入AOT Autograd中。
from functorch.compile import min_cut_rematerialization_partition
# Zero out the gradients so we can do a comparison later
a.grad, b.grad, c.grad, d.grad = (None,) * 4
# Lets set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.
# This will show us how the recomputation has modified the graph.
aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)
res = aot_fn(a, b, c, d).sum().backward()
def forward(self, primals_1, primals_2, primals_3, primals_4):
add = torch.ops.aten.add(primals_1, primals_2); primals_1 = primals_2 = None
add_1 = torch.ops.aten.add(add, primals_3); add = primals_3 = None
add_2 = torch.ops.aten.add(add_1, primals_4); add_1 = primals_4 = None
cos = torch.ops.aten.cos(add_2)
cos_1 = torch.ops.aten.cos(cos); cos = None
return [cos_1, add_2]
def forward(self, add_2, tangents_1):
cos = torch.ops.aten.cos(add_2)
sin = torch.ops.aten.sin(cos); cos = None
neg = torch.ops.aten.neg(sin); sin = None
mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None
sin_1 = torch.ops.aten.sin(add_2); add_2 = None
neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None
mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None
return [mul_1, mul_1, mul_1, mul_1]
我们可以看到,与默认分区器相比,前向传递现在输出的张量更少,并且在反向传递中重新计算了一些操作。现在让我们尝试使用NNC编译器来执行操作符融合(注意,我们还有一个包装函数 - memory_efficient_fusion,它在内部使用min_cut_rematerialization_partition和Torchscript编译器来实现与以下代码相同的效果)。
# Lets set up the partitioner and NNC compiler.
aot_recompute_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile, partition_fn=min_cut_rematerialization_partition)
# Correctness checking. Lets clone the input so that we can check grads.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_recompute_nnc_fn(*cloned_inputs)
loss = res.sum()
loss.backward()
assert torch.allclose(ref, res)
assert torch.allclose(a.grad, cloned_a.grad)
assert torch.allclose(b.grad, cloned_b.grad)
assert torch.allclose(c.grad, cloned_c.grad)
assert torch.allclose(d.grad, cloned_d.grad)
最后,让我们对不同函数进行基准测试
bench(fn, large_inputs, "Eager")
bench(aot_nnc_fn, large_inputs, "AOT")
bench(aot_recompute_nnc_fn, large_inputs, "AOT_Recomp")
Eager, Fwd = 740.7676504226401 us, Bwd = 1560.5240693548694 us
AOT, Fwd = 713.8530415249988 us, Bwd = 909.1200679540634 us
AOT_Recomp, Fwd = 712.2249767417088 us, Bwd = 791.4606417762116 us
我们观察到,前向和后向延迟都比默认的分区器有所改善(并且比急切模式要好得多)。在前向传递中较少的输出和后向传递中较少的输入,以及融合,使得内存带宽利用率更高,从而进一步加速。
实际用法¶
为了在CUDA设备上的实际使用,我们将AOTAutograd封装在一个方便的包装器中 - memory_efficient_fusion。在GPU上进行融合时使用这个!
from functorch.compile import memory_efficient_fusion