
AOT Autograd - How to use and optimize?


Background

In this tutorial, we will learn how to use AOT Autograd to speed up the training of deep learning models.

For background, AOT Autograd is a toolkit that helps developers accelerate training on PyTorch. Broadly, it has two key features:

  • AOT Autograd traces the forward and backward graphs ahead of time. Having the forward and backward graphs available ahead of time facilitates joint graph optimizations such as recomputation or activation checkpointing.

  • AOT Autograd provides a simple mechanism to compile the extracted forward and backward graphs through deep learning compilers, such as NVFuser, NNC, TVM, and others.

What will you learn?

In this tutorial, we look at how AOT Autograd can be combined with backend compilers to accelerate the training of PyTorch models. More specifically, you will learn:

  • How to use AOT Autograd?

  • How does AOT Autograd use backend compilers to perform operator fusion?

  • How does AOT Autograd enable training-specific optimizations such as recomputation?

So, let's get started.

Setup

Let's set up a simple model.

import torch

def fn(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()
# Test that it works
a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]
ref = fn(a, b, c, d)
loss = ref.sum()
loss.backward()

Use AOT Autograd

Now, let's use AOT Autograd and look at the extracted forward and backward graphs. Internally, AOT uses a __torch_dispatch__-based tracing mechanism to extract the forward and backward graphs and wraps them in torch.fx GraphModule containers. Note that AOT Autograd tracing is different from the usual Fx symbolic tracing: AOT Autograd uses the Fx GraphModule only to represent the traced graphs (and not for tracing itself).

AOT Autograd then sends these forward and backward graphs to the user-supplied compilers. So, let's write a compiler that just prints the graph.

from functorch.compile import aot_function

# The compiler_fn is called after the forward and backward graphs are extracted.
# Here, we just print the code in the compiler_fn. Return of this function is a callable.
def compiler_fn(fx_module: torch.fx.GraphModule, _):
    print(fx_module.code)
    return fx_module

# Pass on the compiler_fn to the aot_function API
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)

# Run the aot_print_fn once to trigger the compilation and print the graphs
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)
res.sum().backward()
assert torch.allclose(ref, res)
def forward(self, primals_1, primals_2, primals_3, primals_4):
    add = torch.ops.aten.add(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1 = torch.ops.aten.add(add, primals_3);  add = primals_3 = None
    add_2 = torch.ops.aten.add(add_1, primals_4);  add_1 = primals_4 = None
    cos = torch.ops.aten.cos(add_2)
    cos_1 = torch.ops.aten.cos(cos)
    return [cos_1, add_2, cos]
    



def forward(self, add_2, cos, tangents_1):
    sin = torch.ops.aten.sin(cos);  cos = None
    neg = torch.ops.aten.neg(sin);  sin = None
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin(add_2);  add_2 = None
    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None
    return [mul_1, mul_1, mul_1, mul_1]
    

The above code prints the Fx graphs of the forward and backward passes. You can see that, in addition to the original inputs of the forward pass, the forward graph outputs some additional tensors. These tensors are saved for gradient computation in the backward pass. We will come back to them later when we discuss recomputation.
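Since the compiler callbacks receive ordinary torch.fx.GraphModule objects, we can also inspect the extracted graphs programmatically instead of only printing them. The sketch below (the count_aten_ops helper is our own, not part of the functorch API) tallies the aten calls in each graph; for the forward graph above it would report three add and two cos calls.

from collections import Counter

# A compiler callback can walk the graph with the standard torch.fx APIs.
# Here we simply count how many times each aten op appears and return the
# module unchanged (the callback must return a callable).
def count_aten_ops(fx_module: torch.fx.GraphModule, _):
    print(Counter(str(node.target)
                  for node in fx_module.graph.nodes
                  if node.op == "call_function"))
    return fx_module

aot_count_fn = aot_function(fn, fw_compiler=count_aten_ops, bw_compiler=count_aten_ops)
aot_count_fn(cloned_a, cloned_b, cloned_c, cloned_d).sum().backward()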

Operator Fusion

Now that we have learned how to print the forward and backward graphs with AOT Autograd, let's use AOT Autograd with an actual deep learning compiler. In this tutorial, we use the PyTorch Neural Network Compiler (NNC) to perform pointwise operator fusion for CPU devices. For CUDA devices, a suitable alternative is NvFuser. So, let's use NNC.

# AOT Autograd has a suite of already integrated backends. Lets import the NNC compiler backend - ts_compile
from functorch.compile import ts_compile

# Lets compile the forward and backward through ts_compile.
aot_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)

# Correctness checking. Lets clone the input so that we can check grads.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs

res = aot_nnc_fn(*cloned_inputs)
loss = res.sum()
loss.backward()
assert torch.allclose(ref, res)
assert torch.allclose(a.grad, cloned_a.grad)
assert torch.allclose(b.grad, cloned_b.grad)
assert torch.allclose(c.grad, cloned_c.grad)
assert torch.allclose(d.grad, cloned_d.grad)

Let's benchmark the original function against the AOT Autograd + NNC compiled function.

# Lets write a function to benchmark the forward and backward pass
import time
import statistics

def bench(fn, args, prefix):
    warmup = 10
    iterations = 100

    for _ in range(warmup):
        ref = fn(*args)
        ref.sum().backward()
    
    fw_latencies = []
    bw_latencies = []
    for _ in range(iterations):
        for arg in args:
            arg.grad = None

        fw_begin = time.perf_counter()
        ref = fn(*args)
        fw_end = time.perf_counter()

        loss = ref.sum() 

        bw_begin = time.perf_counter()
        loss.backward()
        bw_end = time.perf_counter()

        fw_latencies.append(fw_end - fw_begin)
        bw_latencies.append(bw_end - bw_begin)
    
    avg_fw_latency = statistics.mean(fw_latencies) * 10**6
    avg_bw_latency = statistics.mean(bw_latencies) * 10**6
    print(prefix, "Fwd = " + str(avg_fw_latency) + " us", "Bwd = " + str(avg_bw_latency) + " us", sep=', ')
large_inputs = [torch.randn(1024, 2048, requires_grad=True) for _ in range(4)]

# Benchmark the Eager and AOT Autograd functions
bench(fn, large_inputs, "Eager")
bench(aot_nnc_fn, large_inputs, "AOT")
Eager, Fwd = 982.6959593920038 us, Bwd = 1899.7003795811906 us
AOT, Fwd = 734.2723174951971 us, Bwd = 831.1696897726506 us

With the help of NNC, AOT Autograd speeds up both the forward and the backward pass. If we look at the graphs printed earlier, all the operators are pointwise. Pointwise operators are memory-bandwidth bound and therefore benefit from operator fusion. Looking closely at the numbers, the backward pass gets a higher speedup. This is because the forward pass has to output some intermediate tensors for gradient computation in the backward pass, which prevents it from saving some memory reads and writes. This restriction does not exist in the backward graph.

Recomputation (aka Activation Checkpointing)

Recomputation (often called activation checkpointing) is a technique in which, instead of saving some activations for use in the backward pass, we recompute them during the backward pass. Recomputation saves memory, but at the cost of extra compute.
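For reference, eager-mode PyTorch already exposes plain recomputation through torch.utils.checkpoint; the minimal sketch below (using the toy fn from earlier, on fresh inputs) discards the activations inside the checkpointed region during the forward pass and recomputes them in the backward pass.

import torch.utils.checkpoint as checkpoint

# Activations produced inside the checkpointed call are not saved; fn is re-run
# during backward to regenerate them, trading extra compute for memory.
# (Newer PyTorch versions may warn unless use_reentrant is passed explicitly.)
ckpt_inputs = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]
ckpt_out = checkpoint.checkpoint(fn, *ckpt_inputs)
ckpt_out.sum().backward()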

However, in the presence of a fusing compiler, we can do better. We can recompute the fusion-friendly operators to save memory and then rely on the fusing compiler to fuse the recomputed operators. This reduces both memory usage and runtime. Please refer to this discussion post for more details.

Here, we use AOT Autograd with NNC to perform a similar type of recomputation. At the end of __torch_dispatch__ tracing, AOT Autograd has a forward graph and a joint forward-backward graph. AOT Autograd then uses a partitioner to isolate the forward and backward graphs. In the example above, we used the default partitioner. For this experiment, we will use another partitioner called min_cut_rematerialization_partition to perform smarter, fusion-aware recomputation. The partitioner is configurable, and one can write their own partitioner and plug it into AOT Autograd.
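As a rough sketch of that plug-in point (not part of the tutorial's original code), a partitioner is just a callable that splits the traced joint graph into separate forward and backward modules. The wrapper below simply logs the call and delegates to the built-in default_partition; we assume the partitioner is invoked with the joint GraphModule and its example inputs, so any extra keyword arguments are passed straight through.

from functorch.compile import default_partition

# A partitioner receives the joint forward-backward graph (a torch.fx.GraphModule)
# and returns a forward module and a backward module. This toy wrapper delegates
# to the built-in default_partition and shows where custom logic would go.
def logging_partition(joint_module, joint_inputs, **kwargs):
    print("Partitioning a joint graph with", len(joint_module.graph.nodes), "nodes")
    return default_partition(joint_module, joint_inputs, **kwargs)

# It can be plugged in like any other partitioner, e.g.:
# aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile, partition_fn=logging_partition)

With that plug-in point in mind, let's now switch to min_cut_rematerialization_partition and print the graphs it produces.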

from functorch.compile import min_cut_rematerialization_partition

# Zero out the gradients so we can do a comparison later
a.grad, b.grad, c.grad, d.grad = (None,) * 4

# Lets set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.
# This will show us how the recomputation has modified the graph.
aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)
res = aot_fn(a, b, c, d)
res.sum().backward()
def forward(self, primals_1, primals_2, primals_3, primals_4):
    add = torch.ops.aten.add(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1 = torch.ops.aten.add(add, primals_3);  add = primals_3 = None
    add_2 = torch.ops.aten.add(add_1, primals_4);  add_1 = primals_4 = None
    cos = torch.ops.aten.cos(add_2)
    cos_1 = torch.ops.aten.cos(cos);  cos = None
    return [cos_1, add_2]
    



def forward(self, add_2, tangents_1):
    cos = torch.ops.aten.cos(add_2)
    sin = torch.ops.aten.sin(cos);  cos = None
    neg = torch.ops.aten.neg(sin);  sin = None
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin(add_2);  add_2 = None
    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None
    return [mul_1, mul_1, mul_1, mul_1]
    

We can see that, compared to the default partitioner, the forward pass now outputs fewer tensors and some operations are recomputed in the backward pass. Let's now try the NNC compiler to perform operator fusion (note that we also have a wrapper function - memory_efficient_fusion - which internally uses min_cut_rematerialization_partition and the TorchScript compiler to achieve the same effect as the code below).

# Lets set up the partitioner and NNC compiler.
aot_recompute_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile, partition_fn=min_cut_rematerialization_partition)

# Correctness checking. Lets clone the input so that we can check grads.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs

res = aot_recompute_nnc_fn(*cloned_inputs)
loss = res.sum()
loss.backward()
assert torch.allclose(ref, res)
assert torch.allclose(a.grad, cloned_a.grad)
assert torch.allclose(b.grad, cloned_b.grad)
assert torch.allclose(c.grad, cloned_c.grad)
assert torch.allclose(d.grad, cloned_d.grad)

Finally, let's benchmark the different functions.

bench(fn, large_inputs, "Eager")
bench(aot_nnc_fn, large_inputs, "AOT")
bench(aot_recompute_nnc_fn, large_inputs, "AOT_Recomp")
Eager, Fwd = 740.7676504226401 us, Bwd = 1560.5240693548694 us
AOT, Fwd = 713.8530415249988 us, Bwd = 909.1200679540634 us
AOT_Recomp, Fwd = 712.2249767417088 us, Bwd = 791.4606417762116 us

We observe that both the forward and backward latencies improve over the default partitioner (and are much better than eager). Fewer outputs from the forward pass and fewer inputs to the backward pass, together with fusion, lead to better memory-bandwidth utilization and therefore further speedups.

Actual Usage

For actual usage on CUDA devices, we wrap AOT Autograd in a convenient wrapper - memory_efficient_fusion. Use this for fusion on GPUs!

from functorch.compile import memory_efficient_fusion
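A minimal usage sketch is shown below, assuming a CUDA device is available (the shapes are arbitrary); memory_efficient_fusion takes the function (or module) to be compiled and returns an optimized callable.

# memory_efficient_fusion bundles aot_function with the min-cut rematerialization
# partitioner and the TorchScript compilers preconfigured, so only the function
# to compile needs to be passed in.
if torch.cuda.is_available():
    fused_fn = memory_efficient_fusion(fn)
    cuda_inputs = [torch.randn(1024, 2048, device="cuda", requires_grad=True) for _ in range(4)]
    out = fused_fn(*cuda_inputs)
    out.sum().backward()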