常见问题¶
作者: Mark Saroufim
Does torch.compile 支持训练?¶
torch.compile 支持训练,使用 AOTAutograd 来捕获反向传播:
由 TorchDynamo 的 python
evalframe前端捕获的.forward()图和optimizer.step()。对于 torchdynamo 捕获的每个
.forward()段,它使用 AOTAutograd 生成一个反向图段。每一对前向和后向图(可选地)进行最小割分区,以节省前向和后向之间的最小状态。
前向和反向对被包装在
autograd.function模块中。Usercode 调用
.backward()仍然会触发 eager 的 autograd 引擎, 该引擎将每个编译后的反向图作为一个操作运行,同时也会运行任何非编译的 eager 操作的.backward()函数。
你们支持分布式代码吗?¶
torch.compile 支持 DistributedDataParallel (DDP)。
对其他分布式训练库的支持正在考虑中。
分布式代码在dynamo中具有挑战性的主要原因是AOTAutograd展开了前向和反向传播,并为后端优化提供了两个图。这对于分布式代码来说是一个问题,因为我们希望理想情况下能够将通信操作与计算操作重叠。Eager pytorch通过使用autograd钩子、模块钩子以及对模块状态的修改/变异,以不同的方式实现了DDP/FSDP的这一目标。在dynamo的简单应用中,由于AOTAutograd编译函数与调度器钩子的交互方式,本应在反向传播期间某个操作之后直接运行的钩子可能会被延迟到整个编译的反向操作区域之后。
优化DDP与Dynamo的基本策略概述在 distributed.py 其中主要思想是在DDP桶边界上进行图断开。
当DDP中的每个节点需要与其他节点同步其权重时,它会将梯度和参数组织成桶,从而减少通信时间,并允许节点向其他等待的节点广播其梯度的一部分。
分布式代码中的图断裂意味着你可以期望dynamo及其后端优化分布式程序的计算开销,但不会优化其通信开销。图断裂可能会干扰编译速度的提升,如果减少的图大小剥夺了编译器的融合机会。然而,随着图大小的增加,由于当前大多数计算优化都是局部融合,因此收益会逐渐减少。所以在实践中,这种方法可能是足够的。
我还需要导出整个图吗?¶
对于绝大多数模型,您可能不需要这样做,您可以直接使用
torch.compile(),但在某些情况下,需要完整的图,您可以通过简单地运行
torch.compile(..., nopython=True) 来确保完整的图。这些情况包括:
大规模训练运行,例如需要管道并行和其他高级分片策略的$250K+。
像TensorRT 或AITemplate这样的推理优化器, 它们依赖于比训练优化器更激进的融合。
移动设备上的训练或推理。
未来的工作将包括将通信操作追踪到图中,协调这些操作与计算优化,并优化通信操作。
为什么我的代码崩溃了?¶
如果你的代码在没有启用torch.compile的情况下运行良好,但在启用后开始崩溃,那么最重要的第一步是找出你的失败发生在堆栈的哪个部分。为了排查这个问题,请按照以下步骤操作,只有在之前的步骤成功的情况下才尝试下一步。
torch.compile(..., backend="eager")这只会运行 TorchDynamo 的前向图捕获,然后使用 PyTorch 运行捕获的图。 如果这失败了,那么 TorchDynamo 就有问题。torch.compile(..., backend="aot_eager")这会运行 TorchDynamo 来捕获前向图,然后 AOTAutograd 来跟踪反向图,而无需任何额外的后端编译步骤。PyTorch eager 将用于运行前向和反向图。如果这失败了,那么 AOTAutograd 就有问题。torch.compile(..., backend="inductor")它运行 TorchDynamo 来捕获一个前向图,然后使用 AOTAutograd 和 TorchInductor 编译器来跟踪反向图。如果这失败了,那么 TorchInductor 就有问题
为什么编译速度慢?¶
Dynamo 编译 – TorchDynamo 有一个内置的统计功能,用于收集和显示每个编译阶段所花费的时间。 这些统计信息可以通过在执行
torch._dynamo后调用torch._dynamo.utils.compile_times()来访问。默认情况下,这将返回一个字符串,表示每个 TorchDynamo 函数名称的编译时间。电感器编译– TorchInductor 有一个内置的统计和跟踪功能,用于显示每个编译阶段所花费的时间、输出代码、输出图表可视化和 IR 转储。
env TORCH_COMPILE_DEBUG=1 python repro.py。这是一个调试工具,旨在更容易地调试/理解 TorchInductor 的内部结构,输出结果将类似于这个。该调试跟踪中的每个文件都可以通过torch._inductor.config.trace.*启用/禁用。默认情况下,配置文件和图表都处于禁用状态,因为它们的生成成本较高。有关更多示例,请参见示例调试目录输出。过度重新编译 当 TorchDynamo 编译一个函数(或其中的一部分)时,它会基于对局部变量和全局变量的某些假设,以允许编译器进行优化,并将这些假设表示为在运行时检查特定值的守卫。如果这些守卫中的任何一个失败,Dynamo 将重新编译该函数(或该部分),最多重新编译
torch._dynamo.config.cache_size_limit次。如果你的程序达到了缓存限制,你首先需要确定哪个守卫失败了,以及你的程序的哪一部分触发了它。 重新编译分析器 自动化了将 TorchDynamo 的缓存限制设置为 1 并在仅观察模式下运行你的程序的过程,该模式记录了任何守卫失败的原因。你应该确保运行你的程序至少与你遇到问题时运行的时间(或迭代次数)一样长,并且分析器将在此期间累积统计数据。
from torch._dynamo.utils import CompileProfiler
def my_model():
...
with CompileProfiler() as prof:
profiler_model = torch.compile(my_model, backend=prof)
profiler_model()
print(prof.report())
为什么在生产环境中重新编译?¶
在某些情况下,您可能不希望在程序预热后出现意外的编译。例如,如果您在延迟敏感的应用程序中提供生产流量。为此,TorchDynamo 提供了一种替代模式,其中使用先前编译的图,但不生成新的图:
frozen_toy_example = dynamo.run(toy_example)
frozen_toy_example(torch.randn(10), torch.randn(10))
你是如何加速我的代码的?¶
有3种主要方法可以加速PyTorch代码:
通过垂直融合进行内核融合,将顺序操作融合以避免过多的读/写操作。例如,融合2个连续的余弦意味着你可以进行1次读取和1次写入,而不是2次读取和2次写入。水平融合:最简单的例子是批处理,其中单个矩阵与一批示例相乘,但更一般的情况是分组GEMM,其中一组矩阵乘法被一起调度
乱序执行:编译器的一种通用优化方法,通过提前查看图中的精确数据依赖关系,我们可以决定执行节点的最佳时机以及哪些缓冲区可以重复使用
自动工作分配:类似于乱序执行点,但通过将图的节点匹配到物理硬件或内存等资源,我们可以设计一个合适的调度计划
上述是加速 PyTorch 代码的一般原则,但不同的后端会各自在优化方面做出不同的权衡。例如,Inductor 首先处理它可以融合的所有内容,然后才生成 Triton 内核。它也可以
Triton 还因为自动内存合并、内存管理和每个流多处理器的调度而提供了加速,并且已被设计用于处理分块计算。
然而,无论你使用哪种后端,最好的做法是使用基准测试并尝试使用PyTorch分析器,直观地检查生成的内核,并尝试自己了解发生了什么。
为什么我没有看到速度提升?¶
图形断点¶
使用dynamo时,您看不到期望的速度提升的主要原因是过多的图断裂。那么,什么是图断裂呢?
给定一个程序,例如:
def some_fun(x):
...
torch.compile(some_fun)(x)
...
Torchdynamo 将尝试将 some_fun() 中的所有 torch/tensor 操作编译成一个 FX 图,但它可能无法将所有内容捕获到一个图中。
一些图形中断的原因对于TorchDynamo来说是不可克服的,例如调用除PyTorch之外的C扩展对TorchDynamo来说是不可见的,并且可以在TorchDynamo无法引入必要的保护措施来确保编译后的程序可以安全重用的情况下执行任意操作。
为了最大化性能,尽量减少图表中断的次数是非常重要的。
识别图表中断的原因¶
要识别程序中的所有图形断点及其相关原因,可以使用 torch._dynamo.explain。此工具在提供的函数上运行 TorchDynamo 并聚合遇到的图形断点。以下是使用示例:
import torch
import torch._dynamo as dynamo
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
print("woo")
if b.sum() < 0:
b = b * -1
return x * b
explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
print(explanation)
"""
Dynamo 生成了 3 个图,有 2 个图中断和 6 个操作。
中断原因:
1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
文件 "t2.py",第 16 行,在 toy_example 中
print("woo")
2. generic_jump
文件 "t2.py",第 17 行,在 toy_example 中
if b.sum() < 0:
"""
要在遇到第一个图形断点时抛出错误,您可以使用nopython=True禁用Python回退,如果您曾经使用过基于导出的编译器,这应该是熟悉的。
def toy_example(a, b):
...
torch.compile(toy_example, fullgraph=True, backend=<compiler>)
为什么我更改代码后没有重新编译?¶
如果你通过设置
env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py 启用了动态形状,那么你的代码
在形状变化时不会重新编译。我们已经添加了对动态形状的支持,
这避免了在形状变化小于两倍的情况下重新编译。这在图像大小变化
的CV场景或NLP中的可变序列长度等情况下特别有用。在推理场景中,
通常无法事先知道批量大小是多少,因为你只能从不同的客户端应用中获取数据。
一般来说,TorchDynamo会尽量避免不必要的重新编译。例如,如果TorchDynamo发现了3个图,而你的更改只修改了一个图,那么只有那个图会被重新编译。因此,另一个避免潜在缓慢编译时间的技巧是通过编译模型进行预热,之后随后的编译将会快得多。冷启动编译时间仍然是我们可见跟踪的指标。
为什么我得到的结果不正确?¶
如果你设置环境变量
TORCHDYNAMO_REPRO_LEVEL=4,精度问题也可以被最小化,它采用类似git bisect的模型,完整的重现可能类似于
TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4,我们需要这个的原因是下游编译器会生成代码,无论是Triton代码还是C++后端,这些下游编译器的数值可能在细微之处有所不同,但会对你的训练稳定性产生重大影响。因此,精度调试器对我们检测代码生成中的错误或后端编译器中的错误非常有用。
如果您希望确保在 torch 和 triton 之间的随机数生成是一致的,那么您可以启用 torch._inductor.config.fallback_random = True
为什么我会遇到内存不足的问题?¶
Dynamo 仍然是一个 alpha 产品,因此存在一些导致 OOM 的来源,如果你遇到 OOM,请按以下顺序禁用以下配置,然后在 GitHub 上提交问题,以便我们解决根本问题:
1. 如果你使用动态形状,请尝试禁用它们,我们默认已禁用它们:env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py 2.
在 inductor 中,CUDA 图与 Triton 默认启用,但移除它们可能会缓解一些 OOM 问题:torch._inductor.config.triton.cudagraphs = False。
Does torch.func work with torch.compile (for grad and vmap transforms)?¶
将 torch.func 转换应用于使用 torch.compile 的函数是无效的:
import torch
@torch.compile
def f(x):
return torch.sin(x)
def g(x):
return torch.grad(f)(x)
x = torch.randn(2, 3)
g(x)
这段代码将无法工作。有一个问题 你可以跟踪这个问题。
作为一种变通方法,在 torch.func 函数外部使用 torch.compile:
注意
这是一个实验性功能,可以通过设置torch._dynamo.config.capture_func_transforms=True来使用
import torch
torch._dynamo.config.capture_func_transforms=True
def f(x):
return torch.sin(x)
@torch.compile
def g(x):
return torch.vmap(f)(x)
x = torch.randn(2, 3)
g(x)
在由torch.compile处理的函数内部调用torch.func变换¶
使用 torch.compile 编译 torch.func.grad¶
import torch
torch._dynamo.config.capture_func_transforms=True
def wrapper_fn(x):
return torch.func.grad(lambda x: x.sin().sum())(x)
x = torch.randn(3, 3, 3)
grad_x = torch.compile(wrapper_fn)(x)
编译 torch.vmap 与 torch.compile¶
import torch
torch._dynamo.config.capture_func_transforms=True
def my_fn(x):
return torch.vmap(lambda x: x.sum(1))(x)
x = torch.randn(3, 3, 3)
output = torch.compile(my_fn)(x)
限制¶
目前有一些不支持的情况会导致图中断(即在这些情况下,torch.compile会回退到急切模式下的PyTorch)。我们正在努力改进这一情况,以便在下一个版本(PyTorch 2.2)中得到改善。
1. 被转换的函数的输入和输出必须是张量。 我们目前还不支持类似张量元组这样的内容。
import torch
torch._dynamo.config.capture_func_transforms=True
def fn(x):
x1, x2 = x
return x1 + x2
def my_fn(x):
return torch.func.vmap(fn)(x)
x1 = torch.randn(3, 3, 3)
x2 = torch.randn(3, 3, 3)
# 不支持,回退到急切模式的PyTorch
output = torch.compile(my_fn)((x1, x2))
不支持关键字参数。
import torch
torch._dynamo.config.capture_func_transforms=True
def fn(x, y):
return (x + y).sum()
def my_fn(x, y):
return torch.func.grad(fn)(x, y=y)
x = torch.randn(3, 3)
y = torch.randn(3, 3)
# 不支持,回退到急切模式的PyTorch
output = torch.compile(my_fn)(x, y)
3. 具有可观察副作用的函数。例如,可以修改在函数内部创建的列表,但不能修改在函数外部创建的列表。
import torch
torch._dynamo.config.capture_func_transforms=True
some_list = []
def f(x, y):
some_list.append(1)
return x + y
def my_fn(x, y):
return torch.func.vmap(f)(x, y)
x = torch.ones(2, 3)
y = torch.randn(2, 3)
# 不支持,回退到急切模式的PyTorch
output = torch.compile(my_fn)(x, y)
torch.vmap覆盖一个调用以下列表中一个或多个操作符的函数。
注意
‘stride’, ‘requires_grad’, ‘storage_offset’, ‘layout’, ‘data’, ‘is_coalesced’, ‘is_complex’, ‘is_conj’, ‘is_contiguous’, ‘is_cpu’, ‘is_cuda’, ‘is_distributed’, ‘is_floating_point’, ‘is_inference’, ‘is_ipu’, ‘is_leaf’, ‘is_meta’, ‘is_mkldnn’, ‘is_mps’, ‘is_neg’, ‘is_nested’, ‘is_nonzero’, ‘is_ort’, ‘is_pinned’, ‘is_quantized’, ‘is_same_size’, ‘is_set_to’, ‘is_shared’, ‘is_signed’, ‘is_sparse’, ‘is_sparse_csr’, ‘is_vulkan’, ‘is_xla’, ‘is_xpu’
import torch
torch._dynamo.config.capture_func_transforms=True
def bad_fn(x):
x.stride()
return x
def my_fn(x):
return torch.func.vmap(bad_fn)(x)
x = torch.randn(3, 3, 3)
# 不支持,回退到急切模式的PyTorch
output = torch.compile(my_fn)(x)
编译除支持的函数之外的函数(逃生舱)¶
对于其他转换,作为一种解决方法,使用 torch._dynamo.allow_in_graph
allow_in_graph 是一个逃生舱。如果你的代码无法与
torch.compile 一起使用,后者会内省 Python 字节码,但如果你认为它可以通过符号追踪方法(如 jax.jit)工作,那么请使用
allow_in_graph。
通过使用 allow_in_graph 来注释一个函数,您必须确保您的代码满足以下要求:
函数中的所有输出仅依赖于输入,不依赖于任何捕获的张量。
你的函数是函数式的。也就是说,它不会改变任何状态。这可能会放宽;我们实际上支持从外部看起来是函数式的函数:它们可能包含就地的 PyTorch 操作,但不得改变全局状态或函数的输入。
您的函数不会引发数据相关的错误。
import torch
@torch.compile
def f(x):
return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)
x = torch.randn(2, 3)
f(x)
一个常见的陷阱是使用 allow_in_graph 来注释一个调用 nn.Module 的函数。这是因为输出现在依赖于 nn.Module 的参数。要使其工作,请使用 torch.func.functional_call 来提取模块状态。
NumPy 是否与 torch.compile 兼容?¶
从2.1版本开始,torch.compile 理解原生NumPy程序,这些程序在NumPy数组上工作,并且理解混合的PyTorch-NumPy程序,这些程序通过 x.numpy()、torch.from_numpy 及相关函数在PyTorch和NumPy之间进行转换。
哪些NumPy功能受torch.compile支持?¶
NumPy 在 torch.compile 中遵循 NumPy 2.0 预发布版本。
通常情况下,torch.compile 能够追踪大多数 NumPy 构造,
并且在无法追踪时,它会回退到急切模式并让 NumPy 执行那部分代码。即便如此,在某些特性上,torch.compile 的语义
与 NumPy 的语义略有不同:
NumPy 标量:我们将其建模为 0-D 数组。也就是说,
np.float32(3)在torch.compile下返回一个 0-D 数组。为了避免图中断,最好使用这个 0-D 数组。如果这破坏了您的代码,您可以通过将 NumPy 标量转换为相关的 Python 标量类型bool/int/float来解决此问题。负步长:
np.flip和使用负步长的切片会返回一个副本。类型提升:NumPy 的类型提升将在 NumPy 2.0 中发生变化。新规则在 NEP 50 中描述。
torch.compile实现了 NEP 50,而不是即将被弃用的当前规则。{tril,triu}_indices_from/{tril,triu}_indices返回数组而不是数组元组。
对于我们不支持追踪的其他功能,我们会优雅地回退到 NumPy 进行执行:
非数字数据类型,如日期时间、字符串、字符、空值、结构化数据类型和记录数组。
长数据类型
np.float128/np.complex256和一些无符号数据类型np.uint16/np.uint32/np.uint64。ndarray子类。掩码数组。
像
axes=[(n,k),(k,m)->(n,m)]这样的深奥的 ufunc 机制和 ufunc 方法(例如,np.add.reduce)。排序 / 排序
complex64/complex128数组。NumPy
np.poly1d和np.polynomial。位置参数
out1, out2在具有2个或更多返回值的函数中(out=tuple有效)。__array_function__,__array_interface__和__array_wrap__。ndarray.ctypes属性。
我可以使用 torch.compile 编译 NumPy 代码吗?¶
当然可以!torch.compile 原生理解 NumPy 代码,并将其视为 PyTorch 代码。为此,只需使用 torch.compile 装饰器包装 NumPy 代码。
import torch
import numpy as np
@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)
使用环境变量 TORCH_LOGS=output_code 执行此示例时,我们可以看到
torch.compile 能够将乘法和求和融合到一个 C++ 内核中。
它还能够使用 OpenMP 并行执行它们(原生 NumPy 是单线程的)。
这可以使您的 NumPy 代码速度提高 n 倍,其中 n 是处理器中的核心数量!
以这种方式跟踪NumPy代码还支持在编译代码中进行图形中断。
我可以在CUDA上执行NumPy代码并通过torch.compile计算梯度吗?¶
是的,你可以!为此,你可以在一个torch.device("cuda")上下文中简单地执行你的代码。考虑以下示例
import torch
import numpy as np
@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
with torch.device("cuda"):
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)
在这个例子中,numpy_fn 将在 CUDA 中执行。为了实现这一点,torch.compile 会自动将 X 和 Y 从 CPU 移动到 CUDA,然后将结果 Z 从 CUDA 移动到 CPU。如果我们在同一个程序运行中多次执行这个函数,我们可能希望避免所有这些相当昂贵的内存复制操作。为此,我们只需要调整我们的 numpy_fn,使其接受 CUDA 张量并返回张量。我们可以通过使用 torch.compiler.wrap_numpy 来实现这一点:
@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
在这里,我们显式地在CUDA内存中创建张量,并将它们传递给函数,该函数在CUDA设备上执行所有计算。
wrap_numpy 负责将任何 torch.Tensor 输入标记为在 torch.compile 级别具有 np.ndarray 语义的输入。在编译器内部标记张量是一个非常廉价的操作,因此在运行时不会发生数据复制或数据移动。
使用这个装饰器,我们还可以通过NumPy代码进行微分!
@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = torch.randn(1024, 64, device="cuda", requires_grad=True)
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
Z.backward()
# X.grad 现在保存了计算的梯度
print(X.grad)
我们一直在使用 fullgraph=True,因为在这种情况下图中断是有问题的。
当发生图中断时,我们需要物化 NumPy 数组。由于 NumPy 数组
没有 device 或 requires_grad 的概念,这些信息在图中断期间会丢失。
我们不能通过图中断传播梯度,因为图中断代码可能会执行我们不知道如何区分的任意代码。另一方面,在CUDA执行的情况下,我们可以像第一个示例中那样,通过使用torch.device("cuda")上下文管理器来解决这个问题:
@torch.compile
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
prod = X[:, :, None] * Y[:, None, :]
print("哎呀,图断了!")
return np.sum(prod, axis=(-2, -1))
X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
with torch.device("cuda"):
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
在图中断期间,中间张量仍然需要移动到CPU,但在图中断后恢复跟踪时,图的其余部分仍然在CUDA上进行跟踪。鉴于这种CUDA <> CPU和CPU <> CUDA的移动,图中断在NumPy上下文中相当昂贵,应尽量避免,但至少它们允许通过复杂的代码片段进行跟踪。
如何在torch.compile下调试NumPy代码?¶
调试JIT编译的代码具有挑战性,考虑到现代编译器的复杂性和它们引发的令人畏惧的错误。 关于如何在torch.compile中诊断运行时错误的教程 包含了一些关于如何应对这一任务的技巧和窍门。
如果上述内容不足以准确定位问题的来源,我们还可以使用一些其他特定于NumPy的工具。我们可以通过禁用通过NumPy函数的跟踪来判断错误是否完全存在于PyTorch代码中:
from torch._dynamo import config
config.trace_numpy = False
如果错误存在于被追踪的 NumPy 代码中,我们可以通过导入 import torch._numpy as np 使用 PyTorch 作为后端来急切地执行 NumPy 代码(不使用 torch.compile)。
这应该仅用于 调试目的,绝不是 PyTorch API 的替代品,因为它 性能低得多,并且作为一个
私有 API,可能会在没有通知的情况下更改。无论如何,torch._numpy 是
在 PyTorch 基础上实现的 NumPy 的 Python 实现,并且它被 torch.compile 内部使用,
将 NumPy 代码转换为 Pytorch 代码。它相当容易阅读和修改,
所以如果你在其中发现任何错误,欢迎提交一个 PR 来修复它,或者简单地打开一个 issue。
如果在导入 torch._numpy as np 时程序无法正常工作,那么问题很可能出在 TorchDynamo 中。如果是这种情况,请随时提交一个包含 最小复现示例 的问题。
我对一些NumPy代码进行了torch.compile,但没有看到任何速度提升。¶
最好的起点是 如何调试这些torch.compile问题的通用建议教程。
由于使用了不支持的功能,可能会发生一些图形中断。请参阅
torch.compile 支持哪些 NumPy 功能?。更广泛地说,记住一些广泛使用的 NumPy 功能与编译器不兼容是很有用的。例如,就地修改使得在编译器中推理变得困难,并且通常比它们的外部对应物产生更差的性能。因此,最好避免使用它们。同样适用于使用 out= 参数。相反,倾向于使用外部操作,并让 torch.compile 优化内存使用。同样适用于依赖于数据的操作,如通过布尔掩码进行掩码索引,或依赖于数据的控制流,如 if 或 while 结构。
用于细粒度跟踪的API是什么?¶
在某些情况下,您可能需要从torch.compile编译中排除代码的小部分。本节提供了一些答案,您可以在TorchDynamo API中找到更多关于细粒度追踪的信息。
如何在函数上绘制断点?¶
在函数上设置断点不足以充分表达您希望 PyTorch 执行的操作。您需要更具体地说明您的使用场景。以下是一些您可能需要考虑的最常见用例:
如果你想禁用此函数帧及其递归调用的帧的编译,请使用
torch._dynamo.disable。如果你想要某个特定的操作符,例如
fbgemm使用 eager 模式,请使用torch._dynamo.disallow_in_graph。
一些不常见的用例包括:
如果你想在函数框架上禁用 TorchDynamo,但在递归调用的框架上重新启用它——使用
torch._dynamo.disable(recursive=False)。如果你想防止函数帧的内联 – 在你想要防止内联的函数开头使用
torch._dynamo.graph_break。
什么是 torch._dynamo.disable 和 torch._dynamo.disallow_in_graph 之间的区别¶
Disallow-in-graph 在操作符级别工作,或者更具体地说,是在 TorchDynamo 提取的图中看到的操作符。
Disable 作用于函数框架级别,并决定 TorchDynamo 是否应该查看该函数框架。
什么是 torch._dynamo.disable 和 torch._dynamo_skip 之间的区别¶
注意
torch._dynamo_skip 已被弃用。
您很可能需要 torch._dynamo.disable。但在不太可能的情况下,您可能需要更精细的控制。假设您只想在 a_fn 函数上禁用跟踪,但希望在 aa_fn 和 ab_fn 中继续跟踪。下图演示了这种用例:
在这种情况下,您可以使用 torch._dynamo.disable(recursive=False)。
在之前的版本中,此功能由 torch._dynamo.skip 提供。
现在,这由 torch._dynamo.disable 中的 recursive 标志支持。