注意
点击这里下载完整的示例代码
torch.export 教程¶
创建于:2023年10月02日 | 最后更新:2024年7月22日 | 最后验证:2024年11月05日
作者: William Wen, Zhengxu Chen, Angela Yi
警告
torch.export
及其相关功能目前处于原型状态,可能会发生向后兼容性的重大更改。本教程提供了截至 PyTorch 2.3 版本的 torch.export
使用情况的快照。
torch.export()
是 PyTorch 2.X 中将 PyTorch 模型导出为标准模型表示的方法,旨在在不同的(即无 Python 的)环境中运行。官方文档可以在 这里 找到。
在本教程中,您将学习如何使用torch.export()
从PyTorch程序中提取ExportedProgram
(即单图表示)。我们还将详细介绍为了使您的模型与torch.export
兼容,您可能需要进行的一些考虑/修改。
目录
基本用法¶
torch.export
通过跟踪目标函数并给定示例输入,从 PyTorch 程序中提取单图表示。
torch.export.export()
是 torch.export
的主要入口点。
在本教程中,torch.export
和 torch.export.export()
实际上是同义词,
尽管 torch.export
通常指的是 PyTorch 2.X 的导出过程,而 torch.export.export()
通常指的是实际的函数调用。
torch.export.export()
的签名是:
export(
f: Callable,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
) -> ExportedProgram
torch.export.export()
通过调用 f(*args, **kwargs)
来追踪张量计算图,并将其包装在一个 ExportedProgram
中,该程序可以序列化或稍后使用不同的输入执行。请注意,虽然输出的 ExportedGraph
是可调用的,并且可以像原始输入可调用对象一样调用,但它不是 torch.nn.Module
。我们将在教程的后面详细讨论 dynamic_shapes
参数。
import torch
from torch.export import export
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(100, 10)
def forward(self, x, y):
return torch.nn.functional.relu(self.lin(x + y), inplace=True)
mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))
让我们回顾一下ExportedProgram
中一些值得关注的属性。
graph
属性是从我们导出的函数中追踪得到的 FX graph,即所有 PyTorch 操作的计算图。FX 图具有一些重要的属性:
这些操作是“ATen级别”的操作。
该图是“功能化的”,意味着没有操作是突变。
graph_module
属性是包装 graph
属性的 GraphModule
,以便它可以作为 torch.nn.Module
运行。
print(exported_mod)
print(exported_mod.graph_module)
打印的代码显示,FX图仅包含ATen级别的操作(例如torch.ops.aten
),并且突变已被移除。例如,突变操作torch.nn.functional.relu(..., inplace=True)
在打印的代码中由torch.ops.aten.relu.default
表示,该操作不会突变。原始突变relu
操作的输入的未来使用被替换为非突变relu
操作的额外新输出所取代。
ExportedProgram
中其他感兴趣的属性包括:
graph_signature
– 导出图的输入、输出、参数、缓冲区等。range_constraints
– 约束条件,稍后介绍
print(exported_mod.graph_signature)
请参阅torch.export
文档以获取更多详细信息。
Graph Breaks¶
尽管torch.export
与torch.compile
共享组件,
torch.export
的关键限制,特别是与torch.compile
相比,
是它不支持图中断。这是因为处理图中断涉及使用默认的Python评估来解释不支持的操作,
这与导出用例不兼容。因此,为了使您的模型代码与torch.export
兼容,
您需要修改代码以移除图中断。
在以下情况下需要图形中断:
数据依赖的控制流
class Bad1(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return torch.sin(x)
return torch.cos(x)
import traceback as tb
try:
export(Bad1(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()
使用
.data
访问张量数据
class Bad2(torch.nn.Module):
def forward(self, x):
x.data[0, 0] = 3
return x
try:
export(Bad2(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()
调用不受支持的函数(例如许多内置函数)
class Bad3(torch.nn.Module):
def forward(self, x):
x = x + 1
return x + id(x)
try:
export(Bad3(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()
不支持的Python语言特性(例如抛出异常,匹配语句)
class Bad4(torch.nn.Module):
def forward(self, x):
try:
x = x + 1
raise RuntimeError("bad")
except:
x = x + 2
return x
try:
export(Bad4(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()
非严格导出¶
为了追踪程序,torch.export
使用 TorchDynamo,一个字节码分析引擎,来符号化分析 Python 代码并根据结果构建图。这种分析允许 torch.export
提供更强的安全性保证,但并非所有 Python 代码都受支持,导致这些图的中断。
为了解决这个问题,在PyTorch 2.3中,我们引入了一种新的导出模式,称为非严格模式,在这种模式下,我们使用Python解释器跟踪程序,就像在急切模式下一样执行它,允许我们跳过不支持的Python特性。这是通过添加一个strict=False
标志来实现的。
查看一些导致图形中断的先前示例:
使用
.data
访问张量数据现在可以正常工作
class Bad2(torch.nn.Module):
def forward(self, x):
x.data[0, 0] = 3
return x
bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
print(bad2_nonstrict.module()(torch.ones(3, 3)))
调用不受支持的函数(例如许多内置函数)跟踪
通过,但在这种情况下,id(x)
在图中被特化为一个常量整数。这是因为 id(x)
不是一个张量操作,因此该操作不会被记录在图中。
class Bad3(torch.nn.Module):
def forward(self, x):
x = x + 1
return x + id(x)
bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
不支持的Python语言特性(例如抛出异常,匹配
语句现在也会被追踪。
class Bad4(torch.nn.Module):
def forward(self, x):
try:
x = x + 1
raise RuntimeError("bad")
except:
x = x + 2
return x
bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
print(bad4_nonstrict.module()(torch.ones(3, 3)))
然而,仍有一些功能需要对原始模块进行重写:
控制流操作¶
torch.export
实际上确实支持数据依赖的控制流。
但这些需要使用控制流操作来表达。例如,
我们可以使用 cond
操作来修复上面的控制流示例,如下所示:
from functorch.experimental.control_flow import cond
class Bad1Fixed(torch.nn.Module):
def forward(self, x):
def true_fn(x):
return torch.sin(x)
def false_fn(x):
return torch.cos(x)
return cond(x.sum() > 0, true_fn, false_fn, [x])
exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
有一些关于cond
的限制需要注意:
谓词(即
x.sum() > 0
)必须返回一个布尔值或单元素张量。操作数(即
[x]
)必须是张量。分支函数(即
true_fn
和false_fn
)的签名必须与操作数匹配,并且它们都必须返回具有相同元数据的单个张量(例如,dtype
、shape
等)。分支函数不能改变输入或全局变量。
分支函数无法访问闭包变量,除非函数在方法的范围内定义,此时可以访问
self
。
有关cond
的更多详细信息,请查看cond文档。
约束/动态形状¶
操作可以针对不同的张量形状有不同的特殊化/行为,因此默认情况下,
torch.export
要求输入到 ExportedProgram
的张量形状与初始 torch.export.export()
调用中给出的相应示例输入相同。
如果我们尝试在下面的示例中使用不同形状的张量运行 ExportedProgram
,我们会得到一个错误:
class MyModule2(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(100, 10)
def forward(self, x, y):
return torch.nn.functional.relu(self.lin(x + y), inplace=True)
mod2 = MyModule2()
exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))
try:
exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
tb.print_exc()
我们可以使用torch.export.export()
的dynamic_shapes
参数来放宽这个约束,它允许我们使用torch.export.Dim
(文档)来指定输入张量的哪些维度是动态的。
对于输入可调用对象的每个张量参数,我们可以指定从维度到torch.export.Dim
的映射。
torch.export.Dim
本质上是一个命名的符号整数,具有可选的最小和最大边界。
然后,torch.export.export()
的dynamic_shapes
参数的格式是从输入可调用对象的张量参数名称到维度 –> 维度映射的映射,如上所述。
如果没有为张量参数的维度提供torch.export.Dim
,则该维度假定为静态。
torch.export.Dim
的第一个参数是符号整数的名称,用于调试。
然后我们可以指定一个可选的最小和最大边界(包括边界)。下面,我们展示一个使用示例。
在下面的示例中,我们的输入
inp1
的第一个维度没有约束,但第二个维度的大小必须在区间 [4, 18] 内。
from torch.export import Dim
inp1 = torch.randn(10, 10, 2)
class DynamicShapesExample1(torch.nn.Module):
def forward(self, x):
x = x[:, 2:]
return torch.relu(x)
inp1_dim0 = Dim("inp1_dim0")
inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
dynamic_shapes1 = {
"x": {0: inp1_dim0, 1: inp1_dim1},
}
exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)
print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))
try:
exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
except Exception:
tb.print_exc()
try:
exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
except Exception:
tb.print_exc()
try:
exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
except Exception:
tb.print_exc()
请注意,如果我们的示例输入到torch.export
不满足dynamic_shapes
给出的约束条件,那么我们会得到一个错误。
inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18)
dynamic_shapes1_bad = {
"x": {0: inp1_dim0, 1: inp1_dim1_bad},
}
try:
export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
except Exception:
tb.print_exc()
我们可以通过使用相同的torch.export.Dim
对象来强制不同张量的维度相等,例如,在矩阵乘法中:
inp2 = torch.randn(4, 8)
inp3 = torch.randn(8, 2)
class DynamicShapesExample2(torch.nn.Module):
def forward(self, x, y):
return x @ y
inp2_dim0 = Dim("inp2_dim0")
inner_dim = Dim("inner_dim")
inp3_dim1 = Dim("inp3_dim1")
dynamic_shapes2 = {
"x": {0: inp2_dim0, 1: inner_dim},
"y": {0: inner_dim, 1: inp3_dim1},
}
exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)
print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))
try:
exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
except Exception:
tb.print_exc()
我们也可以用其他维度来描述一个维度。在详细说明一个维度时有一些限制,但通常,形式为A * Dim + B
的表达式应该是可行的。
class DerivedDimExample1(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
foo = DerivedDimExample1()
x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
derived_dynamic_shapes1 = ({0: dimx}, {0: dimy})
derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1)
print(derived_dim_example1.module()(torch.randn(4), torch.randn(5)))
try:
derived_dim_example1.module()(torch.randn(4), torch.randn(6))
except Exception:
tb.print_exc()
class DerivedDimExample2(torch.nn.Module):
def forward(self, z, y):
return z[1:] + y[1::3]
foo = DerivedDimExample2()
z, y = torch.randn(4), torch.randn(10)
dx = torch.export.Dim("dx", min=3, max=6)
dz = dx + 1
dy = dx * 3 + 1
derived_dynamic_shapes2 = ({0: dz}, {0: dy})
derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2)
print(derived_dim_example2.module()(torch.randn(7), torch.randn(19)))
我们实际上可以使用torch.export
来指导我们哪些dynamic_shapes
约束是必要的。我们可以通过放松所有约束(回想一下,如果我们不为维度提供约束,默认行为是约束到示例输入的确切形状值)并让torch.export
出错来实现这一点。
inp4 = torch.randn(8, 16)
inp5 = torch.randn(16, 32)
class DynamicShapesExample3(torch.nn.Module):
def forward(self, x, y):
if x.shape[0] <= 16:
return x @ y[:, :16]
return y
dynamic_shapes3 = {
"x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
"y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
}
try:
export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
except Exception:
tb.print_exc()
我们可以看到,错误信息为我们提供了关于动态形状约束的建议修复方法。让我们按照这些建议进行操作(具体建议可能略有不同):
def suggested_fixes():
inp4_dim1 = Dim('shared_dim')
# suggested fixes below
inp4_dim0 = Dim('inp4_dim0', max=16)
inp5_dim1 = Dim('inp5_dim1', min=17)
inp5_dim0 = inp4_dim1
# end of suggested fixes
return {
"x": {0: inp4_dim0, 1: inp4_dim1},
"y": {0: inp5_dim0, 1: inp5_dim1},
}
dynamic_shapes3_fixed = suggested_fixes()
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))
请注意,在上面的例子中,因为我们在dynamic_shapes_example3
中限制了x.shape[0]
的值,所以即使有一个原始的if
语句,导出的程序也是合理的。
如果你想了解为什么torch.export
生成了这些约束,你可以
使用环境变量TORCH_LOGS=dynamic,dynamo
重新运行脚本,
或者使用torch._logging.set_logs
。
import logging
torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
# reset to previous values
torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)
我们可以使用range_constraints
字段查看ExportedProgram
的符号形状范围。
print(exported_dynamic_shapes_example3.range_constraints)
自定义操作¶
torch.export
可以导出带有自定义操作符的 PyTorch 程序。
目前,注册自定义操作以供torch.export
使用的步骤如下:
使用
torch.library
定义自定义操作(参考),就像定义其他自定义操作一样
@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(input: torch.Tensor) -> torch.Tensor:
print("custom_op called!")
return torch.relu(x)
定义一个自定义操作的
"Meta"
实现,该实现返回一个与预期输出形状相同的空张量
@custom_op.register_fake
def custom_op_meta(x):
return torch.empty_like(x)
从你想要导出的代码中调用自定义操作,使用
torch.ops
class CustomOpExample(torch.nn.Module):
def forward(self, x):
x = torch.sin(x)
x = torch.ops.my_custom_library.custom_op(x)
x = torch.cos(x)
return x
像以前一样导出代码
exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
exported_custom_op_example.graph_module.print_readable()
print(exported_custom_op_example.module()(torch.randn(3, 3)))
注意在上面的输出中,自定义操作包含在导出的图中。
当我们调用导出的图作为函数时,原始的自定义操作被调用,
如print
调用所示。
如果您在C++中实现了一个自定义操作符,请参考
此文档
使其与torch.export
兼容。
分解¶
默认情况下,由torch.export
生成的图返回一个仅包含功能性ATen操作符的图。这个功能性ATen操作符集(或“操作集”)包含大约2000个操作符,所有这些操作符都是功能性的,也就是说,它们不会改变或别名输入。你可以在这里here找到所有ATen操作符的列表,并且你可以通过检查op._schema.is_mutable
来检查一个操作符是否是功能性的,例如:
print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)
默认情况下,您希望运行导出图的环境应支持所有这些约2000个操作符。但是,如果您的特定环境只能支持约2000个操作符的子集,您可以在导出的程序上使用以下API。
def run_decompositions(
self: ExportedProgram,
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]]
) -> ExportedProgram
run_decompositions
接受一个分解表,这是一个将运算符映射到指定如何将该运算符减少或分解为其他ATen运算符的等效序列的函数。
run_decompositions
的默认分解表是
Core ATen decomposition table
,它将把所有 ATen 操作符分解为
Core ATen Operator Set
,该集合仅包含约 180 个操作符。
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
def forward(self, x):
return self.linear(x)
ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)
core_ir_ep = ep.run_decompositions()
print(core_ir_ep.graph)
请注意,在运行run_decompositions
后,
torch.ops.aten.t.default
操作符,它不属于核心ATen操作集,
已被替换为torch.ops.aten.permute.default
,这是核心ATen操作集的一部分。
大多数ATen操作符已经有分解,这些分解位于这里。如果你想使用一些现有的分解函数,你可以将你想要分解的操作符列表传递给get_decompositions函数,该函数将使用现有的分解实现返回一个分解表。
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
def forward(self, x):
return self.linear(x)
ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)
from torch._decomp import get_decompositions
decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int])
core_ir_ep = ep.run_decompositions(decomp_table)
print(core_ir_ep.graph)
如果对于你想要分解的ATen操作符没有现有的分解函数,欢迎向PyTorch提交一个实现分解的拉取请求!
ExportDB¶
torch.export
只会从 PyTorch 程序中导出一个单一的计算图。由于这一要求,某些 Python 或 PyTorch 功能将无法与 torch.export
兼容,这将需要用户重写部分模型代码。我们在教程的前面已经看到了这样的例子——例如,使用 cond
重写 if 语句。
ExportDB 是记录支持和不支持的 Python/PyTorch 功能的标准参考文档,适用于 torch.export
。它本质上是一个程序示例列表,每个示例代表一个特定的 Python/PyTorch 功能的使用及其与 torch.export
的交互。示例还按类别标记,以便更容易搜索。
例如,让我们使用ExportDB来更好地理解谓词在cond
操作符中的工作原理。
我们可以查看名为cond_predicate
的示例,它有一个torch.cond
标签。示例代码如下:
def cond_predicate(x):
"""
The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
- ``torch.Tensor`` with a single element
- boolean expression
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
"""
pred = x.dim() > 2 and x.shape[2] > 10
return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
更一般地说,当以下情况之一发生时,ExportDB 可以用作参考:
在尝试
torch.export
之前,您已经知道您的模型使用了一些复杂的Python/PyTorch功能,并且您想知道torch.export
是否支持该功能。当尝试
torch.export
时,出现了失败,不清楚如何解决它。
ExportDB 并不详尽,但旨在涵盖典型 PyTorch 代码中的所有用例。如果有重要的 Python/PyTorch 功能应添加到 ExportDB 或由 torch.export
支持,请随时联系我们。
运行导出的程序¶
由于torch.export
仅是一种图捕获机制,急切地调用由torch.export
生成的工件将等同于运行急切模块。为了优化导出程序的执行,我们可以通过torch.compile
将此导出的工件传递给后端,如Inductor,AOTInductor,或TensorRT。
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
def forward(self, x):
x = self.linear(x)
return x
inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))
# Run it eagerly
res = ep.module()(inp)
print(res)
# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
import torch._export
import torch._inductor
# Note: these APIs are subject to change
# Compile the exported program to a .so using ``AOTInductor``
with torch.no_grad():
so_path = torch._inductor.aot_compile(ep.module(), [inp])
# Load and run the .so file in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
res = torch._export.aot_load(so_path, device="cuda")(inp)