Shortcuts

基于TorchScript的ONNX导出器

注意

要使用 TorchDynamo 而不是 TorchScript 导出 ONNX 模型,请参阅 torch.onnx.dynamo_export()

示例:从 PyTorch 到 ONNX 的 AlexNet

这里是一个简单的脚本,它将预训练的AlexNet导出为一个名为alexnet.onnx的ONNX文件。 调用torch.onnx.export运行模型一次以跟踪其执行,然后将跟踪的模型导出到指定的文件中:

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()

# 提供输入和输出名称以设置模型图中值的显示名称。设置这些名称不会改变图的语义,仅用于提高可读性。
# 网络的输入包括输入的扁平列表(即你传递给forward()方法的值),后面是参数的扁平列表。你可以部分指定名称,即提供一个比模型输入数量短的列表,我们将只设置从开始的那部分名称。
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

生成的 alexnet.onnx 文件包含一个二进制 协议缓冲区, 其中包含您导出的模型的网络结构和参数(在本例中为 AlexNet)。参数 verbose=True 使导出器打印出模型的人类可读表示:

```html
# 这些是网络的输入和参数,它们采用了我们之前指定的名称。
# 我们之前指定的名称。
graph(%actual_input_1 : Float(10, 3, 224, 224)
      %learned_0 : Float(64, 3, 11, 11)
      %learned_1 : Float(64)
      %learned_2 : Float(192, 64, 5, 5)
      %learned_3 : Float(192)
      # ---- 为简洁起见省略 ----
      %learned_14 : Float(1000, 4096)
      %learned_15 : Float(1000)) {
  # 每个语句由一些输出张量(及其类型)组成,
  # 要运行的操作符(及其属性,例如内核、步幅等),
  # 其输入张量(%actual_input_1, %learned_0, %learned_1)
  %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
  %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
  %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
  # ---- 为简洁起见省略 ----
  %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
  # 动态意味着形状是未知的。这可能是因为我们的实现存在限制(我们希望在未来的版本中修复),
  # 或者是真正的动态形状。
  %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
  %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
  %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
  %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
  # ---- 为简洁起见省略 ----
  %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/

您还可以使用ONNX库验证输出, 您可以使用pip安装该库:

pip install onnx

然后,你可以运行:

import onnx

# 加载 ONNX 模型
model = onnx.load("alexnet.onnx")

# 检查模型是否结构良好
onnx.checker.check_model(model)

# 打印图的人类可读表示
print(onnx.helper.printable_graph(model.graph))

您还可以使用众多支持ONNX的运行时之一运行导出的模型。例如,在安装ONNX Runtime之后,您可以加载并运行模型:

import onnxruntime as ort
import numpy as np

ort_session = ort.InferenceSession("alexnet.onnx")

outputs = ort_session.run(
    None,
    {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])

这里有一个更深入的教程,关于导出模型并在ONNX Runtime中运行它

追踪与脚本

在内部,torch.onnx.export() 需要一个 torch.jit.ScriptModule 而不是 一个 torch.nn.Module。如果传入的模型不是一个 ScriptModuleexport() 将使用 tracing 将其转换为一个:

  • 跟踪: 如果使用一个不是已经是ScriptModule的模块调用torch.onnx.export(),它首先会执行与torch.jit.trace()等效的操作,即使用给定的args执行模型一次,并记录执行期间发生的所有操作。这意味着如果你的模型是动态的,例如,根据输入数据改变行为,导出的模型将不会捕捉到这种动态行为。 我们建议检查导出的模型,并确保操作符看起来合理。跟踪将展开循环和if语句,导出一个与跟踪运行完全相同的静态图。如果你想用动态控制流导出模型,你需要使用脚本

  • 脚本编写:通过脚本编译模型保留了动态控制流,并且适用于不同大小的输入。要使用脚本编写:

    • 使用 torch.jit.script() 生成一个 ScriptModule

    • 使用 ScriptModule 作为模型调用 torch.onnx.export()。仍然需要 args,但它们将仅在内部用于生成示例输出,以便捕获输出的类型和形状。不会执行跟踪。

请参阅TorchScript 简介TorchScript 了解更多详情,包括如何组合追踪和脚本来适应不同模型的特定需求。

避免陷阱

避免使用NumPy和内置的Python类型

PyTorch模型可以使用NumPy或Python类型和函数编写,但在追踪过程中,任何NumPy或Python类型的变量(而不是torch.Tensor)都会被转换为常量,如果这些值应根据输入而变化,则会产生错误的结果。

例如,而不是在 numpy.ndarrays 上使用 numpy 函数:

# 不好!在跟踪过程中将被常量替换。
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)

在 torch.Tensors 上使用 torch 操作符:

# 很好!张量操作将在跟踪过程中被捕获。
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)

而不是使用 torch.Tensor.item()(它将一个张量转换为Python内置数字):

# 不好!y.item() 在追踪过程中将被替换为一个常量。
def forward(self, x, y):
    return x.reshape(y.item(), -1)

使用 torch 对单元素张量的隐式转换支持:

# 很好!y 将在追踪过程中作为变量保留。
def forward(self, x, y):
    return x.reshape(y, -1)

避免使用 Tensor.data

使用 Tensor.data 字段可能会产生不正确的跟踪,从而导致不正确的 ONNX 图。请改用 torch.Tensor.detach()。(正在努力 完全移除 Tensor.data)。

在使用tensor.shape进行跟踪模式时避免就地操作

在追踪模式下,从 tensor.shape 获得的形状被追踪为张量,并且共享相同的内存。这可能会导致最终输出值的不匹配。作为一种解决方法,在这些情况下避免使用就地操作。例如,在模型中:

class Model(torch.nn.Module):
  def forward(self, states):
      batch_size, seq_length = states.shape[:2]
      real_seq_length = seq_length
      real_seq_length += 2
      return real_seq_length + seq_length

real_seq_lengthseq_length 在追踪模式下共享相同的内存。 这可以通过重写就地操作来避免:

real_seq_length = real_seq_length + 2

限制

类型

  • 仅支持 torch.Tensors、可以简单转换为 torch.Tensors 的数值类型(例如 float、int)以及这些类型的元组和列表作为模型输入或输出。在 tracing 模式下,接受 Dict 和 str 输入和输出,但:

    • 任何依赖于字典或字符串输入值的计算将被替换为在单次追踪执行期间看到的常量值

    • 任何输出如果是字典类型,将会被静默替换为其值的扁平序列(键将被移除)。例如,{"foo": 1, "bar": 2} 变为 (1, 2)

    • 任何输出如果是字符串类型,将会被静默移除。

  • 由于ONNX对嵌套序列的支持有限,在脚本模式下,不支持涉及元组和列表的某些操作。 特别是,不支持将元组附加到列表中。在跟踪模式下,嵌套序列将在跟踪期间自动展平。

操作符实现中的差异

由于操作符实现的不同,在不同的运行时上运行导出的模型可能会产生不同的结果,这些结果可能与PyTorch的结果不同。通常这些差异在数值上是小的,因此只有当您的应用程序对这些小的差异敏感时,这才应该是一个问题。

不支持的张量索引模式

无法导出的张量索引模式列在下面。 如果您在导出不包含以下任何不支持模式的模型时遇到问题,请仔细检查您是否正在使用最新的opset_version进行导出。

读取 / 获取

当对张量进行读取索引时,不支持以下模式:

# 包含负值的张量索引。
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# 解决方法:使用正值索引。

写入 / 设置

当对张量进行写操作时,不支持以下模式:

# 如果有任何张量索引的秩 >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# 解决方法:使用秩 >= 2 的单个张量索引,
#              或使用秩 == 1 的多个连续张量索引。

# 多个不连续的张量索引
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# 解决方法:转置 `data` 使得张量索引是连续的。

# 包含负值的张量索引
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# 解决方法:使用正值索引。

# 需要对 new_data 进行隐式广播
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# 解决方法:显式扩展 new_data。
# 示例:
#   data 形状: [3, 4, 5]
#   new_data 形状: [5]
#   广播后的预期 new_data 形状: [2, 2, 2, 5]

添加对操作符的支持

当导出一个包含不支持操作符的模型时,您会看到类似以下的错误消息:

RuntimeError: ONNX 导出失败: 无法导出操作符 foo

当这种情况发生时,你可以采取以下几种措施:

  1. 更改模型以不使用该操作符。

  2. 创建一个符号函数以转换操作符并将其注册为自定义符号函数。

  3. 为 PyTorch 贡献代码,以将相同的符号函数添加到 torch.onnx 本身。

如果你决定实现一个符号函数(我们希望你能将其贡献回PyTorch!),以下是你如何开始的方法:

ONNX 导出器内部机制

“符号函数”是一种将 PyTorch 操作符分解为一系列 ONNX 操作符组合的函数。

在导出过程中,TorchScript 图中的每个节点(包含一个 PyTorch 操作符)都会按照拓扑顺序被导出器访问。 访问一个节点时,导出器会查找为该操作符注册的符号函数。符号函数是用 Python 实现的。一个名为 foo 的操作符的符号函数看起来像这样:

def foo(
  g,
  input_0: torch._C.Value,
  input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
  """
  通过调用 `g.op()` 更新图 g,添加表示此 PyTorch 函数的 ONNX 操作。

  参数:
    g (Graph): 要写入 ONNX 表示的图。
    input_0 (Value): 表示包含此操作符第一个输入的变量的值。
    input_1 (Value): 表示包含此操作符第二个输入的变量的值。

  返回:
    一个 Value 或 Values 列表,指定计算与给定输入的原始 PyTorch 操作符等效的 ONNX 节点。

    如果无法转换为 ONNX,则返回 None。
  """
  ...

The torch._C 类型是围绕在 C++ 中定义的类型的 Python 包装器,这些类型定义在 ir.h 中。

添加符号函数的过程取决于操作符的类型。

ATen 运算符

ATen 是 PyTorch 的内置张量库。 如果该操作符是一个 ATen 操作符(在 TorchScript 图中以 aten:: 前缀显示),请确保它尚未被支持。

支持的运算符列表

访问自动生成的支持的TorchScript操作符列表 以了解每个opset_version中支持的操作符详情。

添加对aten或量化操作符的支持

如果操作符不在上述列表中:

  • torch/onnx/symbolic_opset.py 中定义符号函数,例如 torch/onnx/symbolic_opset9.py。 确保函数名称与 ATen 函数相同,该函数可能在 torch/_C/_VariableFunctions.pyitorch/nn/functional.pyi 中声明(这些文件是在构建时生成的,因此在您构建 PyTorch 之前不会出现在您的检出中)。

  • 默认情况下,第一个参数是ONNX图。 其他参数名称必须与.pyi文件中的名称完全匹配, 因为分派是通过关键字参数完成的。

  • 在符号函数中,如果运算符在 ONNX 标准运算符集中, 我们只需要创建一个节点来表示图中的 ONNX 运算符。 如果不是,我们可以组合几个具有等效语义的标准运算符来表示 ATen 运算符。

这是一个处理ELU操作符缺失符号函数的示例。

如果我们运行以下代码:

print(
    torch.jit.trace(
        torch.nn.ELU(), # 模块
        torch.ones(1)   # 示例输入
    ).graph
)

我们看到了类似的内容:

graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
      %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %4 : float = prim::Constant[value=1.]()
  %5 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=1]()
  %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
  return (%7)

由于我们在图中看到 aten::elu,我们知道这是一个ATen操作符。

我们检查了ONNX操作符列表, 并确认Elu在ONNX中是标准化的。

我们在 torch/nn/functional.pyi 中找到了 elu 的签名:

def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

我们在 symbolic_opset9.py 中添加了以下几行:

def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
    return g.op("Elu", input, alpha_f=alpha)

现在 PyTorch 能够导出包含 aten::elu 运算符的模型了!

参见 torch/onnx/symbolic_opset*.py 文件以获取更多示例。

torch.autograd.Functions

如果操作符是 torch.autograd.Function 的子类,有三种方式可以导出它。

静态符号方法

您可以在函数类中添加一个名为 symbolic 的静态方法。它应该返回 ONNX 操作符,这些操作符表示函数在 ONNX 中的行为。例如:

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

内联自动求导函数

在未为其后续的 torch.autograd.Function 提供静态符号方法的情况下, 或者在未提供将 prim::PythonOp 注册为自定义符号函数的方法的情况下, torch.onnx.export() 尝试内联对应于该 torch.autograd.Function 的图,使得 该函数被分解为函数内部使用的各个操作符。 只要这些单独的操作符被支持,导出就应该成功。例如:

class MyLogExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        h = input.exp()
        return h.log().log()

此模型目前没有静态符号方法,但其导出方式如下:

graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
    %1 : float = onnx::Exp[](%input)
    %2 : float = onnx::Log[](%1)
    %3 : float = onnx::Log[](%2)
    return (%3)

如果你需要避免内联 torch.autograd.Function,你应该使用 operator_export_type 设置为 ONNX_FALLTHROUGHONNX_ATEN_FALLBACK 来导出模型。

自定义操作符

您可以使用自定义运算符导出您的模型,这些运算符包括许多标准 ONNX 运算符的组合,或者由自定义的 C++ 后端驱动。

ONNX-脚本函数

如果一个运算符不是标准的 ONNX 运算符,但可以由多个现有的 ONNX 运算符组成,你可以利用 ONNX-script 创建一个外部 ONNX 函数来支持该运算符。 你可以按照以下示例导出它:

import onnxscript
# 需要对齐三个操作集版本
# 这是 (1) ONNX 函数中的操作集版本
from onnxscript.onnx_opset import opset15 as op
opset_version = 15

x = torch.randn(1, 2, 3, 4, requires_grad=True)
model = torch.nn.SELU()

custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

@onnxscript.script(custom_opset)
def Selu(X):
    alpha = 1.67326  # 自动包装为常量
    gamma = 1.0507
    alphaX = op.CastLike(alpha, X)
    gammaX = op.CastLike(gamma, X)
    neg = gammaX * (alphaX * op.Exp(X) - alphaX)
    pos = gammaX * X
    zero = op.CastLike(0, X)
    return op.Where(X <= zero, neg, pos)

# setType API 为 ONNX 形状/类型推断提供形状/类型
def custom_selu(g: jit_utils.GraphContext, X):
    return g.onnxscript_op(Selu, X).setType(X.type())

# 注册自定义符号函数
# 需要对齐三个操作集版本
# 这是 (2) 注册表中的操作集版本
torch.onnx.register_custom_op_symbolic(
    symbolic_name="aten::selu",
    symbolic_fn=custom_selu,
    opset_version=opset_version,
)

# 需要对齐三个操作集版本
# 这是 (2) 导出器中的操作集版本
torch.onnx.export(
    model,
    x,
    "model.onnx",
    opset_version=opset_version,
    # 仅在您想要指定操作集版本 > 1 时需要。
    custom_opsets={"onnx-script": 2}
)

上面的示例将其导出为“onnx-script”操作集中的自定义操作符。 在导出自定义操作符时,您可以在导出时使用custom_opsets字典指定自定义域版本。如果未指定,自定义操作集版本默认为1。

注意:请小心对齐上述示例中提到的opset版本,并确保在导出步骤中使用它们。 关于如何编写onnx-script函数的示例用法是onnx-script活跃开发中的测试版本。 请关注最新的ONNX-script

C++ 运算符

如果一个模型使用了在C++中实现的定制操作符,如在 使用定制C++操作符扩展TorchScript中所述, 你可以按照以下示例导出它:

from torch.onnx import symbolic_helper


# 定义自定义符号函数
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


# 注册自定义符号函数
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super().__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # 调用自定义操作
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


model = FooModel(attr1, attr2)
torch.onnx.export(
    model,
    (example_input1, example_input1),
    "model.onnx",
    # 仅在需要指定 opset 版本 > 1 时需要
    custom_opsets={"custom_domain": 2}
)

上面的示例将其导出为“custom_domain”操作集中的自定义操作符。 在导出自定义操作符时,您可以在导出时使用custom_opsets字典指定自定义域版本。如果未指定,自定义操作集版本默认为1。

使用模型的运行时需要支持自定义操作。请参阅 Caffe2 自定义操作ONNX Runtime 自定义操作, 或您选择的运行时的文档。

一次性发现所有不可转换的ATen操作

当导出因无法转换的ATen操作失败时,实际上可能存在多个这样的操作,但错误消息仅提到第一个。要一次性发现所有无法转换的操作,您可以:

# 准备模型、参数、opset版本
...

torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    model, args, opset_version=opset_version
)

print(set(unconvertible_ops))

该集合是近似的,因为在转换过程中可能会删除一些操作,这些操作不需要转换。其他一些操作可能只有部分支持,在特定输入下转换会失败,但这应该能让你大致了解哪些操作不受支持。请随时在GitHub上提出操作支持请求。

常见问题

问:我已经导出了我的LSTM模型,但它的输入大小似乎是固定的?

跟踪器记录示例输入的形状。如果模型应接受动态形状的输入,请在调用torch.onnx.export()时设置dynamic_axes

问:如何导出包含循环的模型?

问:如何导出带有基本类型输入(例如 int、float)的模型?

在 PyTorch 1.9 中添加了对原始数值类型输入的支持。 然而,导出器不支持带有字符串输入的模型。

问:ONNX 是否支持隐式标量数据类型转换?

ONNX 标准没有处理这部分,但导出器会尝试处理。 标量作为常量张量导出。 导出器会确定标量的正确数据类型。在极少数情况下,如果无法确定,您需要手动指定数据类型,例如 dtype=torch.float32。 如果您遇到任何错误,请 [创建一个 GitHub 问题](https://github.com/pytorch/pytorch/issues).

问:张量列表可以导出到ONNX吗?

是的,对于 opset_version >= 11,因为 ONNX 在 opset 11 中引入了 Sequence 类型。

Python API

函数

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True)[源代码]

将模型导出为ONNX格式。

如果model既不是torch.jit.ScriptModule也不是 torch.jit.ScriptFunction,这将运行 model一次,以便将其转换为要导出的TorchScript图 (相当于torch.jit.trace())。因此,它对动态控制流的支持与torch.jit.trace()相同。

Parameters
  • 模型 (torch.nn.Module, torch.jit.ScriptModuletorch.jit.ScriptFunction) – 要导出的模型。

  • args (tupletorch.Tensor) –

    args 可以结构化为以下几种形式:

    1. 仅包含参数的元组:

      args = (x, y, z)
      

    元组应包含模型输入,使得 model(*args) 是模型的一个有效调用。任何非 Tensor 参数将被硬编码到导出的模型中;任何 Tensor 参数将成为导出模型的输入,按其在元组中出现的顺序排列。

    1. 一个 Tensor:

      args = torch.Tensor([1])
      

    这等同于包含该 Tensor 的 1 元组。

    1. 以字典形式结尾的参数元组:

      args = (
          x,
          {
              "y": input_y,
              "z": input_z
          }
      )
      

    元组中除最后一个元素外的所有元素将作为非关键字参数传递,命名参数将从最后一个元素中设置。如果字典中不存在命名参数,则将其分配默认值,如果没有提供默认值,则分配 None。

    注意

    如果字典是 args 元组的最后一个元素,它将被解释为包含命名参数。为了将字典作为最后一个非关键字参数传递,请在 args 元组的最后一个元素中提供一个空字典。例如,不要写:

    torch.onnx.export(
        model,
        (
            x,
            # 错误:将被解释为命名参数
            {y: z}
        ),
        "test.onnx.pb"
    )
    

    请写成:

    torch.onnx.export(
        model,
        (
            x,
            {y: z},
            {}
        ),
        "test.onnx.pb"
    )
    

  • f (联合[字符串, BytesIO]) – 一个类文件对象(使得 f.fileno() 返回一个文件描述符) 或者一个包含文件名的字符串。 一个二进制协议缓冲区将被写入到这个文件中。

  • export_params (bool, 默认 True) – 如果为 True,所有参数将被导出。如果你想导出一个未训练的模型,请将其设置为 False。在这种情况下,导出的模型将首先将其所有参数作为参数,顺序由 model.state_dict().values() 指定。

  • 详细 (布尔值, 默认 False) – 如果为 True,将打印正在导出的模型的描述到标准输出。此外,最终的 ONNX 图将包含从导出的模型中提取的 doc_string` 字段,该字段提到了 model 的源代码位置。如果为 True,将启用 ONNX 导出器日志记录。

  • 训练 (枚举, , 默认 TrainingMode.EVAL) –

    • TrainingMode.EVAL: 以推理模式导出模型。

    • TrainingMode.PRESERVE: 如果 model.training 为

      False,则以推理模式导出模型;如果 model.training 为 True,则以训练模式导出模型。

    • TrainingMode.TRAINING: 以训练模式导出模型。禁用可能干扰训练的优化。

      这可能会干扰训练。

  • input_names (列表 of 字符串, 默认空列表) – 按顺序分配给图的输入节点的名称。

  • output_names (列表 of 字符串, 默认空列表) – 按顺序分配给图的输出节点的名称。

  • operator_export_type (枚举, 默认 OperatorExportTypes.ONNX) –

    • OperatorExportTypes.ONNX: 将所有操作导出为常规的 ONNX 操作

      (在默认的操作集域中)。

    • OperatorExportTypes.ONNX_FALLTHROUGH: 尝试将所有操作转换

      为标准 ONNX 操作在默认的操作集域中。如果无法做到这一点 (例如,因为尚未添加将特定 torch 操作转换为 ONNX 的支持), 则回退到将操作导出到自定义操作集域而不进行转换。适用于 自定义操作 以及 ATen 操作。为了使导出的模型可用,运行时必须支持 这些非标准操作。

    • OperatorExportTypes.ONNX_ATEN: 所有 ATen 操作(在 TorchScript 命名空间“aten”中)

      都导出为 ATen 操作(在操作集域“org.pytorch.aten”中)。 ATen 是 PyTorch 的内置张量库,因此 这指示运行时使用 PyTorch 的实现来执行这些操作。

      警告

      以这种方式导出的模型可能只能由 Caffe2 运行。

      如果操作的实现中的数值差异导致 PyTorch 和 Caffe2 之间的行为差异很大(这在未训练的模型中更常见),这可能很有用。

    • OperatorExportTypes.ONNX_ATEN_FALLBACK: 尝试将每个 ATen 操作

      (在 TorchScript 命名空间“aten”中)导出为常规 ONNX 操作。如果我们无法做到这一点 (例如,因为尚未添加将特定 torch 操作转换为 ONNX 的支持), 则回退到导出 ATen 操作。请参阅 OperatorExportTypes.ONNX_ATEN 的文档以获取上下文。 例如:

      graph(%0 : Float):
      %3 : int = prim::Constant[value=0]()
      # 不支持转换
      %4 : Float = aten::triu(%0, %3)
      # 支持转换
      %5 : Float = aten::mul(%4, %0)
      return (%5)
      

      假设 aten::triu 在 ONNX 中不受支持,这将导出为:

      graph(%0 : Float):
      %1 : Long() = onnx::Constant[value={0}]()
      # 未转换
      %2 : Float = aten::ATen[operator="triu"](%0, %1)
      # 已转换
      %3 : Float = onnx::Mul(%2, %0)
      return (%3)
      

      如果 PyTorch 是使用 Caffe2 构建的(即使用 BUILD_CAFFE2=1),那么 将启用 Caffe2 特定的行为,包括对由 量化 模块生成的操作的特殊支持。

      警告

      以这种方式导出的模型可能只能由 Caffe2 运行。

  • opset_version (int, 默认值为17) – 目标的 默认 (ai.onnx) opset 版本。必须大于等于7且小于等于17。

  • do_constant_folding (bool, 默认 True) – 应用常量折叠优化。 常量折叠将替换一些所有输入均为常量的操作, 并用预先计算的常量节点代替。

  • dynamic_axes (dict[string, dict[int, string]] or dict[string, list(int)], default empty dict) –

    默认情况下,导出的模型将使所有输入和输出张量的形状完全匹配args中给出的形状。要指定张量的轴为动态的(即仅在运行时已知),请将dynamic_axes设置为具有以下模式的字典:

    • KEY (str): 输入或输出的名称。每个名称也必须在input_names

      output_names中提供。

    • VALUE (dict or list): 如果是字典,键是轴索引,值是轴名称。如果是

      列表,每个元素是一个轴索引。

    例如:

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"]
    )
    

    生成:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # 轴 0
            }
            dim {
              dim_value: 2  # 轴 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # 轴 0
    ...
    

    而:

    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"],
        dynamic_axes={
            # 字典值:手动命名的轴
            "x": {0: "my_custom_axis_name"},
            # 列表值:自动名称
            "sum": [0],
        }
    )
    

    生成:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # 轴 0
            }
            dim {
              dim_value: 2  # 轴 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # 轴 0
    ...
    

  • keep_initializers_as_inputs (bool, 默认 None) –

    如果为 True,导出图中所有初始化器(通常对应于参数)也将作为图的输入添加。如果为 False,则初始化器不会作为图的输入添加,只有非参数输入会被添加为输入。

    这可能会允许后端/运行时进行更好的优化(例如常量折叠)。

    如果为 True,deduplicate_initializers 过程将不会执行。这意味着具有重复值的初始化器不会被去重,并将被视为图的不同输入。这允许在导出后在运行时提供不同的输入初始化器。

    如果 opset_version < 9,初始化器必须作为图的输入,此参数将被忽略,行为将等同于将此参数设置为 True。

    如果为 None,则行为将自动选择如下:

    • 如果 operator_export_type=OperatorExportTypes.ONNX,行为等同于

      将此参数设置为 False。

    • 否则,行为等同于将此参数设置为 True。

  • custom_opsets (dict[str, int], 默认空字典) –

    一个具有以下结构的dict:

    • KEY (str): opset域名

    • VALUE (int): opset版本

    如果自定义opset被model引用但在该字典中未提及, 则opset版本设置为1。仅应通过此参数指示自定义opset域名和版本。

  • export_modules_as_functions (boolsettypenn.Module, 默认 False) –

    启用标志 将所有 nn.Module 前向调用导出为 ONNX 中的局部函数。或者是一个集合,用于指示要导出为 ONNX 中局部函数的特定模块类型。 此功能需要 opset_version >= 15,否则导出将失败。这是因为 opset_version < 15 意味着 IR 版本 < 8,这意味着不支持局部函数。 模块变量将作为函数属性导出。函数属性有两类。

    1. 注解属性:通过 PEP 526 风格 进行类型注解的类变量将作为属性导出。 注解属性不会在 ONNX 局部函数子图中使用,因为它们不是由 PyTorch JIT 跟踪创建的,但它们可能被消费者用来确定是否用特定融合内核替换函数。

    2. 推断属性:在模块内部由操作符使用的变量。属性名称将带有前缀“inferred::”。这是为了与从 python 模块注解中检索的预定义属性区分开来。推断属性在 ONNX 局部函数子图中使用。

    • False(默认):将 nn.Module 前向调用导出为细粒度节点。

    • True:将所有 nn.Module 前向调用导出为局部函数节点。

    • nn.Module 类型的集合:将 nn.Module 前向调用导出为局部函数节点,

      仅当 nn.Module 的类型在集合中找到时。

  • autograd_inlining (布尔值, 默认 True) – 用于控制是否内联自动求导函数的标志。 详情请参阅 https://github.com/pytorch/pytorch/pull/74765

Raises
  • torch.onnx.errors.CheckerError – 如果ONNX检查器检测到无效的ONNX图。

  • torch.onnx.errors.UnsupportedOperatorError – 如果ONNX图无法导出,因为它使用了导出器不支持的运算符。

  • torch.onnx.errors.OnnxExporterError – 导出过程中可能发生的其他错误。 所有错误都是 errors.OnnxExporterError 的子类。

torch.onnx.export_to_pretty_string(model, args, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, export_type=None, google_printer=False, opset_version=None, keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True, do_constant_folding=True, dynamic_axes=None)[源代码]

类似于 export(),但返回 ONNX 模型的文本表示。仅列出参数中的差异。所有其他参数与 export() 相同。

Parameters
  • add_node_names (布尔值, 默认值 True) – 是否设置 NodeProto.name。除非 google_printer=True,否则这没有区别。

  • google_printer (bool, 默认 False) – 如果为 False,将返回模型的自定义、紧凑表示。如果为 True,将返回 protobuf 的 Message::DebugString(),这会更详细。

Returns

一个包含ONNX模型人类可读表示的UTF-8字符串。

torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[源代码]

为自定义操作符注册一个符号函数。

当用户为自定义/贡献操作注册符号时,强烈建议通过setType API为该操作添加形状推断,否则在某些极端情况下,导出的图可能会有不正确的形状推断。setType的一个示例是test_aten_embedding_2,位于test_operators.py中。

请参阅模块文档中的“自定义操作符”以获取示例用法。

Parameters
  • symbolic_name (str) – 自定义操作符的名称,格式为“::”。

  • symbolic_fn (可调用对象) – 一个函数,它接收ONNX图和当前操作符的输入参数,并返回要添加到图中的新操作符节点。

  • opset_version (int) – 要注册的ONNX操作集版本。

torch.onnx.unregister_custom_op_symbolic(symbolic_name, opset_version)[源代码]

取消注册 symbolic_name

参见模块文档中的“自定义操作符”以获取示例用法。

Parameters
  • symbolic_name (str) – 自定义操作符的名称,格式为“::”。

  • opset_version (int) – 要取消注册的ONNX操作集版本。

torch.onnx.select_model_mode_for_export(model, mode)[源代码]

一个上下文管理器,用于临时将 model 的训练模式设置为 mode,并在退出 with 块时重置它。

Parameters
  • 模型 – 与 model 参数的类型和含义相同,用于 export()

  • 模式 (训练模式) – 与 训练 参数的类型和含义相同,用于 导出()

torch.onnx.is_in_onnx_export()[源代码]

返回是否处于ONNX导出过程中。

Return type

bool

torch.onnx.enable_log()[源代码]

启用ONNX日志记录。

torch.onnx.disable_log()[源代码]

禁用 ONNX 日志记录。

torch.onnx.verification.find_mismatch(model, input_args, do_constant_folding=True, training=<TrainingMode.EVAL: 0>, opset_version=None, keep_initializers_as_inputs=True, verbose=False, options=None)[源代码]

查找原始模型与导出模型之间的所有不匹配项。

实验性。API可能会发生变化。

此工具帮助调试原始 PyTorch 模型与导出的 ONNX 模型之间的不匹配问题。它对模型图进行二分搜索,以找到展示不匹配的最小子图。

Parameters
Returns

包含不匹配信息的GraphInfo对象。

Return type

图信息

示例:

>>> import torch
>>> import torch.onnx.verification
>>> torch.manual_seed(0)
>>> opset_version = 15
>>> # 为 aten::relu 定义一个自定义符号函数。
>>> # 自定义符号函数是错误的,这将导致不匹配。
>>> def incorrect_relu_symbolic_function(g, self):
...     return self
>>> torch.onnx.register_custom_op_symbolic(
...     "aten::relu",
...     incorrect_relu_symbolic_function,
...     opset_version=opset_version,
... )
>>> class Model(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.layers = torch.nn.Sequential(
...             torch.nn.Linear(3, 4),
...             torch.nn.ReLU(),
...             torch.nn.Linear(4, 5),
...             torch.nn.ReLU(),
...             torch.nn.Linear(5, 6),
...         )
...     def forward(self, x):
...         return self.layers(x)
>>> graph_info = torch.onnx.verification.find_mismatch(
...     Model(),
...     (torch.randn(2, 3),),
...     opset_version=opset_version,
... )
===================== 图分区的不匹配信息: ======================
================================ 不匹配错误 ================================
张量不接近!
不匹配的元素:12 / 12 (100.0%)
最大绝对差异:0.2328854203224182 在索引 (1, 2) 处(允许的最大差异为 1e-07)
最大相对差异:0.699536174352349 在索引 (1, 3) 处(允许的最大差异为 0.001)
==================================== 树: =====================================
5 X   __2 X    __1 ✓
id:  |  id: 0 |  id: 00
     |        |
     |        |__1 X (aten::relu)
     |           id: 01
     |
     |__3 X    __1 ✓
        id: 1 |  id: 10
              |
              |__2 X     __1 X (aten::relu)
                 id: 11 |  id: 110
                        |
                        |__1 ✓
                           id: 111
=========================== 不匹配的叶子子图: ===========================
['01', '110']
============================= 不匹配的节点类型: =============================
{'aten::relu': 2}

JitScalarType

在 torch 中定义的标量类型。

torch.onnx.verification.GraphInfo

GraphInfo 包含 TorchScript 图及其转换后的 ONNX 图的验证信息。

torch.onnx.verification.VerificationOptions

ONNX 导出验证的选项。