Shortcuts

torch.export 的源代码

import builtins
import copy
import dataclasses
import inspect
import io
import os
import sys
import typing
import warnings
from enum import auto, Enum
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

import torch
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

from torch.utils._pytree import (
    FlattenFunc,
    FromDumpableContextFn,
    ToDumpableContextFn,
    UnflattenFunc,
)

if TYPE_CHECKING:
    # 在类型检查期间导入以下模块以启用代码智能功能,
    # 不要无条件导入,因为它们导入了sympy,而导入sympy非常慢
    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint


__all__ = [
    "Constraint",
    "Dim",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
    "dims",
    "dynamic_dim",
    "export",
    "load",
    "register_dataclass",
    "save",
    "unflatten",
    "FlatArgsAdapter",
    "UnflattenedModule",
]


from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule


PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


[docs]def export( mod: torch.nn.Module, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None, *, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, strict: bool = True, preserve_module_call_signature: Tuple[str, ...] = (), ) -> ExportedProgram: """ :func:`export` 接受一个任意的 Python 可调用对象(一个 nn.Module、一个函数或 一个方法)以及示例输入,并生成一个仅表示函数张量计算的前向图,以提前(AOT)方式 进行,随后可以使用不同的输入执行或序列化。 该前向图(1)生成功能性 ATen 运算符集中的 标准化运算符(以及任何用户指定的自定义运算符),(2)消除了所有 Python 控制流和数据结构 (有一定例外),并且(3)记录了为证明此标准化和控制流消除对未来输入有效的形状约束集。 **正确性保证** 在追踪过程中,:func:`export()` 会记录用户程序和底层 PyTorch 运算符内核所做的形状相关假设。 只有当这些假设成立时,输出的 :class:`ExportedProgram` 才被认为是有效的。 追踪会对输入张量的形状(而非值)做出假设。 这些假设必须在图捕获时进行验证,以使 :func:`export` 成功。 具体来说: - 对输入张量静态形状的假设会自动验证,无需额外努力。 - 对输入张量动态形状的假设需要通过使用 :func:`Dim` API 构造动态维度,并通过 ``dynamic_shapes`` 参数 将它们与示例输入关联来显式指定。 如果任何假设无法验证,将引发致命错误。 当发生这种情况时,错误消息将包含建议的修复,以验证假设。 例如,:func:`export` 可能会建议对动态维度 ``dim0_x`` 的定义进行以下修复,该维度出现在与输入 ``x`` 关联的形状中,之前定义为 ``Dim("dim0_x")``:: dim = Dim("dim0_x", max=5) 此示例意味着生成的代码要求输入 ``x`` 的第 0 维小于或等于 5 才有效。 您可以检查建议的动态维度定义 修复,然后将其逐字复制到代码中,而无需更改 :func:`export` 调用的 ``dynamic_shapes`` 参数。 参数: mod: 我们将追踪此模块的前向方法。 args: 示例位置输入。 kwargs: 可选的示例关键字输入。 dynamic_shapes: 一个可选参数,类型应为: 1) 一个从 ``f`` 的参数名称到其动态形状规范的字典, 2) 一个元组,按原始顺序指定每个输入的动态形状规范。 如果您在关键字参数上指定动态性,则需要按原始函数签名中定义的顺序传递它们。 张量参数的动态形状可以指定为 (1)一个从动态维度索引到 :func:`Dim` 类型的字典,其中不需要在此字典中包含静态维度索引,但如果包含, 则应映射到 None;或(2)一个 :func:`Dim` 类型或 None 的元组 / 列表,其中 :func:`Dim` 类型对应于动态维度, 静态维度由 None 表示。 字典或张量的元组 / 列表参数通过使用映射或包含的规范序列递归指定。 strict: 启用时(默认),导出函数将通过 TorchDynamo 追踪程序,以确保生成的图的正确性。 否则,导出的程序 将不会验证图中所隐含的假设,并可能导致原始模型和导出模型之间的行为差异。 这在用户需要绕过追踪器中的错误, 或仅希望逐步启用模型安全性时很有用。 请注意,这不会影响生成的 IR 规范,无论此处传递的值如何,模型都将以相同方式序列化。 警告:此选项是实验性的,使用此选项需自行承担风险。 返回: 包含追踪的可调用对象的 :class:`ExportedProgram`。 **可接受的输入/输出类型** 可接受的输入类型(对于 ``args`` 和 ``kwargs``)和输出类型包括: - 原始类型,即 ``torch.Tensor``、``int``、``float``、``bool`` 和 ``str``。 - 数据类,但必须先通过调用 :func:`register_dataclass` 进行注册。 - (嵌套)数据结构,包括 ``dict``、``list``、``tuple``、``namedtuple`` 和 ``OrderedDict``,包含上述所有类型。 """ from ._trace import _export <span