Shortcuts

torch.export 教程

创建于:2023年10月02日 | 最后更新:2024年7月22日 | 最后验证:2024年11月05日

作者: William Wen, Zhengxu Chen, Angela Yi

警告

torch.export 及其相关功能目前处于原型状态,可能会发生向后兼容性的重大更改。本教程提供了截至 PyTorch 2.3 版本的 torch.export 使用情况的快照。

torch.export() 是 PyTorch 2.X 中将 PyTorch 模型导出为标准模型表示的方法,旨在在不同的(即无 Python 的)环境中运行。官方文档可以在 这里 找到。

在本教程中,您将学习如何使用torch.export()从PyTorch程序中提取ExportedProgram(即单图表示)。我们还将详细介绍为了使您的模型与torch.export兼容,您可能需要进行的一些考虑/修改。

目录

基本用法

torch.export 通过跟踪目标函数并给定示例输入,从 PyTorch 程序中提取单图表示。 torch.export.export()torch.export 的主要入口点。

在本教程中,torch.exporttorch.export.export() 实际上是同义词, 尽管 torch.export 通常指的是 PyTorch 2.X 的导出过程,而 torch.export.export() 通常指的是实际的函数调用。

torch.export.export() 的签名是:

export(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
) -> ExportedProgram

torch.export.export() 通过调用 f(*args, **kwargs) 来追踪张量计算图,并将其包装在一个 ExportedProgram 中,该程序可以序列化或稍后使用不同的输入执行。请注意,虽然输出的 ExportedGraph 是可调用的,并且可以像原始输入可调用对象一样调用,但它不是 torch.nn.Module。我们将在教程的后面详细讨论 dynamic_shapes 参数。

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))

让我们回顾一下ExportedProgram中一些值得关注的属性。

graph 属性是从我们导出的函数中追踪得到的 FX graph,即所有 PyTorch 操作的计算图。FX 图具有一些重要的属性:

  • 这些操作是“ATen级别”的操作。

  • 该图是“功能化的”,意味着没有操作是突变。

graph_module 属性是包装 graph 属性的 GraphModule,以便它可以作为 torch.nn.Module 运行。

print(exported_mod)
print(exported_mod.graph_module)

打印的代码显示,FX图仅包含ATen级别的操作(例如torch.ops.aten),并且突变已被移除。例如,突变操作torch.nn.functional.relu(..., inplace=True)在打印的代码中由torch.ops.aten.relu.default表示,该操作不会突变。原始突变relu操作的输入的未来使用被替换为非突变relu操作的额外新输出所取代。

ExportedProgram 中其他感兴趣的属性包括:

  • graph_signature – 导出图的输入、输出、参数、缓冲区等。

  • range_constraints – 约束条件,稍后介绍

print(exported_mod.graph_signature)

请参阅torch.export 文档以获取更多详细信息。

Graph Breaks

尽管torch.exporttorch.compile共享组件, torch.export的关键限制,特别是与torch.compile相比, 是它不支持图中断。这是因为处理图中断涉及使用默认的Python评估来解释不支持的操作, 这与导出用例不兼容。因此,为了使您的模型代码与torch.export兼容, 您需要修改代码以移除图中断。

在以下情况下需要图形中断:

  • 数据依赖的控制流

class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)

import traceback as tb
try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • 使用.data访问张量数据

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • 调用不受支持的函数(例如许多内置函数)

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • 不支持的Python语言特性(例如抛出异常,匹配语句)

class Bad4(torch.nn.Module):
    def forward(self, x):
        try:
            x = x + 1
            raise RuntimeError("bad")
        except:
            x = x + 2
        return x

try:
    export(Bad4(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

非严格导出

为了追踪程序,torch.export 使用 TorchDynamo,一个字节码分析引擎,来符号化分析 Python 代码并根据结果构建图。这种分析允许 torch.export 提供更强的安全性保证,但并非所有 Python 代码都受支持,导致这些图的中断。

为了解决这个问题,在PyTorch 2.3中,我们引入了一种新的导出模式,称为非严格模式,在这种模式下,我们使用Python解释器跟踪程序,就像在急切模式下一样执行它,允许我们跳过不支持的Python特性。这是通过添加一个strict=False标志来实现的。

查看一些导致图形中断的先前示例:

  • 使用.data访问张量数据现在可以正常工作

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
print(bad2_nonstrict.module()(torch.ones(3, 3)))
  • 调用不受支持的函数(例如许多内置函数)跟踪

通过,但在这种情况下,id(x) 在图中被特化为一个常量整数。这是因为 id(x) 不是一个张量操作,因此该操作不会被记录在图中。

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
  • 不支持的Python语言特性(例如抛出异常,匹配

语句现在也会被追踪。

class Bad4(torch.nn.Module):
    def forward(self, x):
        try:
            x = x + 1
            raise RuntimeError("bad")
        except:
            x = x + 2
        return x

bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
print(bad4_nonstrict.module()(torch.ones(3, 3)))

然而,仍有一些功能需要对原始模块进行重写:

控制流操作

torch.export 实际上确实支持数据依赖的控制流。 但这些需要使用控制流操作来表达。例如, 我们可以使用 cond 操作来修复上面的控制流示例,如下所示:

from functorch.experimental.control_flow import cond

class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)
        def false_fn(x):
            return torch.cos(x)
        return cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))

有一些关于cond的限制需要注意:

  • 谓词(即 x.sum() > 0)必须返回一个布尔值或单元素张量。

  • 操作数(即 [x])必须是张量。

  • 分支函数(即 true_fnfalse_fn)的签名必须与操作数匹配,并且它们都必须返回具有相同元数据的单个张量(例如,dtypeshape 等)。

  • 分支函数不能改变输入或全局变量。

  • 分支函数无法访问闭包变量,除非函数在方法的范围内定义,此时可以访问self

有关cond的更多详细信息,请查看cond文档

约束/动态形状

操作可以针对不同的张量形状有不同的特殊化/行为,因此默认情况下, torch.export 要求输入到 ExportedProgram 的张量形状与初始 torch.export.export() 调用中给出的相应示例输入相同。 如果我们尝试在下面的示例中使用不同形状的张量运行 ExportedProgram,我们会得到一个错误:

class MyModule2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod2 = MyModule2()
exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))

try:
    exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
    tb.print_exc()

我们可以使用torch.export.export()dynamic_shapes参数来放宽这个约束,它允许我们使用torch.export.Dim文档)来指定输入张量的哪些维度是动态的。

对于输入可调用对象的每个张量参数,我们可以指定从维度到torch.export.Dim的映射。 torch.export.Dim本质上是一个命名的符号整数,具有可选的最小和最大边界。

然后,torch.export.export()dynamic_shapes参数的格式是从输入可调用对象的张量参数名称到维度 –> 维度映射的映射,如上所述。 如果没有为张量参数的维度提供torch.export.Dim,则该维度假定为静态。

torch.export.Dim的第一个参数是符号整数的名称,用于调试。 然后我们可以指定一个可选的最小和最大边界(包括边界)。下面,我们展示一个使用示例。

在下面的示例中,我们的输入 inp1 的第一个维度没有约束,但第二个维度的大小必须在区间 [4, 18] 内。

from torch.export import Dim

inp1 = torch.randn(10, 10, 2)

class DynamicShapesExample1(torch.nn.Module):
    def forward(self, x):
        x = x[:, 2:]
        return torch.relu(x)

inp1_dim0 = Dim("inp1_dim0")
inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
dynamic_shapes1 = {
    "x": {0: inp1_dim0, 1: inp1_dim1},
}

exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)

print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))

try:
    exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
except Exception:
    tb.print_exc()

try:
    exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
except Exception:
    tb.print_exc()

try:
    exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
except Exception:
    tb.print_exc()

请注意,如果我们的示例输入到torch.export不满足dynamic_shapes给出的约束条件,那么我们会得到一个错误。

inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18)
dynamic_shapes1_bad = {
    "x": {0: inp1_dim0, 1: inp1_dim1_bad},
}

try:
    export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
except Exception:
    tb.print_exc()

我们可以通过使用相同的torch.export.Dim对象来强制不同张量的维度相等,例如,在矩阵乘法中:

inp2 = torch.randn(4, 8)
inp3 = torch.randn(8, 2)

class DynamicShapesExample2(torch.nn.Module):
    def forward(self, x, y):
        return x @ y

inp2_dim0 = Dim("inp2_dim0")
inner_dim = Dim("inner_dim")
inp3_dim1 = Dim("inp3_dim1")

dynamic_shapes2 = {
    "x": {0: inp2_dim0, 1: inner_dim},
    "y": {0: inner_dim, 1: inp3_dim1},
}

exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)

print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))

try:
    exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
except Exception:
    tb.print_exc()

我们也可以用其他维度来描述一个维度。在详细说明一个维度时有一些限制,但通常,形式为A * Dim + B的表达式应该是可行的。

class DerivedDimExample1(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

foo = DerivedDimExample1()

x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
derived_dynamic_shapes1 = ({0: dimx}, {0: dimy})

derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1)

print(derived_dim_example1.module()(torch.randn(4), torch.randn(5)))

try:
    derived_dim_example1.module()(torch.randn(4), torch.randn(6))
except Exception:
    tb.print_exc()


class DerivedDimExample2(torch.nn.Module):
    def forward(self, z, y):
        return z[1:] + y[1::3]

foo = DerivedDimExample2()

z, y = torch.randn(4), torch.randn(10)
dx = torch.export.Dim("dx", min=3, max=6)
dz = dx + 1
dy = dx * 3 + 1
derived_dynamic_shapes2 = ({0: dz}, {0: dy})

derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2)
print(derived_dim_example2.module()(torch.randn(7), torch.randn(19)))

我们实际上可以使用torch.export来指导我们哪些dynamic_shapes约束是必要的。我们可以通过放松所有约束(回想一下,如果我们不为维度提供约束,默认行为是约束到示例输入的确切形状值)并让torch.export出错来实现这一点。

inp4 = torch.randn(8, 16)
inp5 = torch.randn(16, 32)

class DynamicShapesExample3(torch.nn.Module):
    def forward(self, x, y):
        if x.shape[0] <= 16:
            return x @ y[:, :16]
        return y

dynamic_shapes3 = {
    "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
    "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
}

try:
    export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
except Exception:
    tb.print_exc()

我们可以看到,错误信息为我们提供了关于动态形状约束的建议修复方法。让我们按照这些建议进行操作(具体建议可能略有不同):

def suggested_fixes():
    inp4_dim1 = Dim('shared_dim')
    # suggested fixes below
    inp4_dim0 = Dim('inp4_dim0', max=16)
    inp5_dim1 = Dim('inp5_dim1', min=17)
    inp5_dim0 = inp4_dim1
    # end of suggested fixes
    return {
        "x": {0: inp4_dim0, 1: inp4_dim1},
        "y": {0: inp5_dim0, 1: inp5_dim1},
    }

dynamic_shapes3_fixed = suggested_fixes()
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))

请注意,在上面的例子中,因为我们在dynamic_shapes_example3中限制了x.shape[0]的值,所以即使有一个原始的if语句,导出的程序也是合理的。

如果你想了解为什么torch.export生成了这些约束,你可以 使用环境变量TORCH_LOGS=dynamic,dynamo重新运行脚本, 或者使用torch._logging.set_logs

import logging
torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)

# reset to previous values
torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)

我们可以使用range_constraints字段查看ExportedProgram的符号形状范围。

print(exported_dynamic_shapes_example3.range_constraints)

自定义操作

torch.export 可以导出带有自定义操作符的 PyTorch 程序。

目前,注册自定义操作以供torch.export使用的步骤如下:

  • 使用torch.library定义自定义操作(参考),就像定义其他自定义操作一样

@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(input: torch.Tensor) -> torch.Tensor:
    print("custom_op called!")
    return torch.relu(x)
  • 定义一个自定义操作的"Meta"实现,该实现返回一个与预期输出形状相同的空张量

@custom_op.register_fake
def custom_op_meta(x):
    return torch.empty_like(x)
  • 从你想要导出的代码中调用自定义操作,使用 torch.ops

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x
  • 像以前一样导出代码

exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
exported_custom_op_example.graph_module.print_readable()
print(exported_custom_op_example.module()(torch.randn(3, 3)))

注意在上面的输出中,自定义操作包含在导出的图中。 当我们调用导出的图作为函数时,原始的自定义操作被调用, 如print调用所示。

如果您在C++中实现了一个自定义操作符,请参考 此文档 使其与torch.export兼容。

分解

默认情况下,由torch.export生成的图返回一个仅包含功能性ATen操作符的图。这个功能性ATen操作符集(或“操作集”)包含大约2000个操作符,所有这些操作符都是功能性的,也就是说,它们不会改变或别名输入。你可以在这里here找到所有ATen操作符的列表,并且你可以通过检查op._schema.is_mutable来检查一个操作符是否是功能性的,例如:

print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)

默认情况下,您希望运行导出图的环境应支持所有这些约2000个操作符。但是,如果您的特定环境只能支持约2000个操作符的子集,您可以在导出的程序上使用以下API。

def run_decompositions(
    self: ExportedProgram,
    decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]]
) -> ExportedProgram

run_decompositions 接受一个分解表,这是一个将运算符映射到指定如何将该运算符减少或分解为其他ATen运算符的等效序列的函数。

run_decompositions 的默认分解表是 Core ATen decomposition table ,它将把所有 ATen 操作符分解为 Core ATen Operator Set ,该集合仅包含约 180 个操作符。

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)

ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)

core_ir_ep = ep.run_decompositions()
print(core_ir_ep.graph)

请注意,在运行run_decompositions后, torch.ops.aten.t.default操作符,它不属于核心ATen操作集, 已被替换为torch.ops.aten.permute.default,这是核心ATen操作集的一部分。

大多数ATen操作符已经有分解,这些分解位于这里。如果你想使用一些现有的分解函数,你可以将你想要分解的操作符列表传递给get_decompositions函数,该函数将使用现有的分解实现返回一个分解表。

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)

ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)

from torch._decomp import get_decompositions
decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int])
core_ir_ep = ep.run_decompositions(decomp_table)
print(core_ir_ep.graph)

如果对于你想要分解的ATen操作符没有现有的分解函数,欢迎向PyTorch提交一个实现分解的拉取请求!

ExportDB

torch.export 只会从 PyTorch 程序中导出一个单一的计算图。由于这一要求,某些 Python 或 PyTorch 功能将无法与 torch.export 兼容,这将需要用户重写部分模型代码。我们在教程的前面已经看到了这样的例子——例如,使用 cond 重写 if 语句。

ExportDB 是记录支持和不支持的 Python/PyTorch 功能的标准参考文档,适用于 torch.export。它本质上是一个程序示例列表,每个示例代表一个特定的 Python/PyTorch 功能的使用及其与 torch.export 的交互。示例还按类别标记,以便更容易搜索。

例如,让我们使用ExportDB来更好地理解谓词在cond操作符中的工作原理。 我们可以查看名为cond_predicate的示例,它有一个torch.cond标签。示例代码如下:

def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
    - ``torch.Tensor`` with a single element
    - boolean expression
    NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

更一般地说,当以下情况之一发生时,ExportDB 可以用作参考:

  1. 在尝试torch.export之前,您已经知道您的模型使用了一些复杂的Python/PyTorch功能,并且您想知道torch.export是否支持该功能。

  2. 当尝试torch.export时,出现了失败,不清楚如何解决它。

ExportDB 并不详尽,但旨在涵盖典型 PyTorch 代码中的所有用例。如果有重要的 Python/PyTorch 功能应添加到 ExportDB 或由 torch.export 支持,请随时联系我们。

运行导出的程序

由于torch.export仅是一种图捕获机制,急切地调用由torch.export生成的工件将等同于运行急切模块。为了优化导出程序的执行,我们可以通过torch.compile将此导出的工件传递给后端,如Inductor,AOTInductor,或TensorRT

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
import torch._export
import torch._inductor

# Note: these APIs are subject to change
# Compile the exported program to a .so using ``AOTInductor``
with torch.no_grad():
so_path = torch._inductor.aot_compile(ep.module(), [inp])

# Load and run the .so file in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
res = torch._export.aot_load(so_path, device="cuda")(inp)

结论

我们介绍了torch.export,这是PyTorch 2.X中从PyTorch程序中导出单一计算图的新方法。特别是,我们展示了几种代码修改和考虑(控制流操作、约束等),这些是导出图时需要进行的。

脚本总运行时间: ( 0 分钟 0.000 秒)

Gallery generated by Sphinx-Gallery

优云智算