torch.export 的源代码
import builtins
import copy
import dataclasses
import inspect
import io
import os
import sys
import typing
import warnings
from enum import auto, Enum
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
import torch
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.utils._pytree import (
FlattenFunc,
FromDumpableContextFn,
ToDumpableContextFn,
UnflattenFunc,
)
if TYPE_CHECKING:
# 在类型检查期间导入以下模块以启用代码智能功能,
# 不要无条件导入,因为它们导入了sympy,而导入sympy非常慢
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
__all__ = [
"Constraint",
"Dim",
"ExportBackwardSignature",
"ExportGraphSignature",
"ExportedProgram",
"ModuleCallEntry",
"ModuleCallSignature",
"dims",
"dynamic_dim",
"export",
"load",
"register_dataclass",
"save",
"unflatten",
"FlatArgsAdapter",
"UnflattenedModule",
]
from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
[docs]def export(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
"""
:func:`export` 接受一个任意的 Python 可调用对象(一个 nn.Module、一个函数或
一个方法)以及示例输入,并生成一个仅表示函数张量计算的前向图,以提前(AOT)方式
进行,随后可以使用不同的输入执行或序列化。 该前向图(1)生成功能性 ATen 运算符集中的
标准化运算符(以及任何用户指定的自定义运算符),(2)消除了所有 Python 控制流和数据结构
(有一定例外),并且(3)记录了为证明此标准化和控制流消除对未来输入有效的形状约束集。
**正确性保证**
在追踪过程中,:func:`export()` 会记录用户程序和底层 PyTorch 运算符内核所做的形状相关假设。
只有当这些假设成立时,输出的 :class:`ExportedProgram` 才被认为是有效的。
追踪会对输入张量的形状(而非值)做出假设。 这些假设必须在图捕获时进行验证,以使 :func:`export`
成功。 具体来说:
- 对输入张量静态形状的假设会自动验证,无需额外努力。
- 对输入张量动态形状的假设需要通过使用 :func:`Dim` API 构造动态维度,并通过 ``dynamic_shapes`` 参数
将它们与示例输入关联来显式指定。
如果任何假设无法验证,将引发致命错误。 当发生这种情况时,错误消息将包含建议的修复,以验证假设。
例如,:func:`export` 可能会建议对动态维度 ``dim0_x`` 的定义进行以下修复,该维度出现在与输入 ``x``
关联的形状中,之前定义为 ``Dim("dim0_x")``::
dim = Dim("dim0_x", max=5)
此示例意味着生成的代码要求输入 ``x`` 的第 0 维小于或等于 5 才有效。 您可以检查建议的动态维度定义
修复,然后将其逐字复制到代码中,而无需更改 :func:`export` 调用的 ``dynamic_shapes`` 参数。
参数:
mod: 我们将追踪此模块的前向方法。
args: 示例位置输入。
kwargs: 可选的示例关键字输入。
dynamic_shapes:
一个可选参数,类型应为:
1) 一个从 ``f`` 的参数名称到其动态形状规范的字典,
2) 一个元组,按原始顺序指定每个输入的动态形状规范。
如果您在关键字参数上指定动态性,则需要按原始函数签名中定义的顺序传递它们。
张量参数的动态形状可以指定为
(1)一个从动态维度索引到 :func:`Dim` 类型的字典,其中不需要在此字典中包含静态维度索引,但如果包含,
则应映射到 None;或(2)一个 :func:`Dim` 类型或 None 的元组 / 列表,其中 :func:`Dim` 类型对应于动态维度,
静态维度由 None 表示。 字典或张量的元组 / 列表参数通过使用映射或包含的规范序列递归指定。
strict: 启用时(默认),导出函数将通过 TorchDynamo 追踪程序,以确保生成的图的正确性。 否则,导出的程序
将不会验证图中所隐含的假设,并可能导致原始模型和导出模型之间的行为差异。 这在用户需要绕过追踪器中的错误,
或仅希望逐步启用模型安全性时很有用。 请注意,这不会影响生成的 IR 规范,无论此处传递的值如何,模型都将以相同方式序列化。
警告:此选项是实验性的,使用此选项需自行承担风险。
返回:
包含追踪的可调用对象的 :class:`ExportedProgram`。
**可接受的输入/输出类型**
可接受的输入类型(对于 ``args`` 和 ``kwargs``)和输出类型包括:
- 原始类型,即 ``torch.Tensor``、``int``、``float``、``bool`` 和 ``str``。
- 数据类,但必须先通过调用 :func:`register_dataclass` 进行注册。
- (嵌套)数据结构,包括 ``dict``、``list``、``tuple``、``namedtuple`` 和 ``OrderedDict``,包含上述所有类型。
"""
from ._trace import _export
<span