Shortcuts

torch.export.exported_program 的源代码

```python
import copy
import dataclasses
import functools
import types
import warnings
from collections import namedtuple
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from torch.fx.immutable_collections import immutable_dict, immutable_list

if TYPE_CHECKING:
    # 在类型检查期间导入以下模块以启用代码智能功能,
    # 例如在工具(如pylance)中的自动补全,即使这些模块在用户代码中没有显式导入。

    import sympy

    from torch.utils._sympy.value_ranges import ValueRanges

import torch
import torch.utils._pytree as pytree
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

from .graph_signature import (  # noqa: F401
    _sig_to_specs,
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymIntArgument,
    TensorArgument,
)


# Names exported by ``from torch.export.exported_program import *``.
__all__ = ["ExportedProgram", "ModuleCallEntry", "ModuleCallSignature"]


# Type alias for a graph pass: a callable that receives a
# ``torch.fx.GraphModule`` and returns either a ``PassResult`` or ``None``.
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


[docs]@dataclasses.dataclass class ModuleCallSignature: inputs: List[ArgumentSpec] outputs: List[ArgumentSpec] in_spec: pytree.TreeSpec out_spec: pytree.TreeSpec
@dataclasses.dataclass()
class ModuleCallEntry:
    """Pairs a module name with an optional ``ModuleCallSignature``.

    Fields:
        fqn: name string for the module (presumably its fully-qualified
            name -- confirm against callers).
        signature: the call signature, or ``None`` (the default) when
            none has been recorded.
    """

    fqn: str
    # Left as None unless a signature is supplied.
    signature: Optional[ModuleCallSignature] = None
def _disable_prexisiting_fake_mode(fn):
    """Return a wrapper that calls ``fn`` inside ``maybe_disable_fake_tensor_mode()``.

    The wrapper forwards all positional and keyword arguments unchanged and
    copies ``fn``'s metadata via ``functools.wraps``. (Presumably this
    suspends any fake tensor mode active at call time -- see
    ``maybe_disable_fake_tensor_mode``.)
    """

    @functools.wraps(fn)
    def _run_without_fake_mode(*args, **kwargs):
        with maybe_disable_fake_tensor_mode():
            return fn(*args, **kwargs)

    return _run_without_fake_mode


def _fx_collection_equivalence_fn(
    spec1_type: Optional[type],
    spec1_context: pytree.Context,
    spec2_type: Optional[type],
    spec2_context: pytree.Context,
) -> bool:
    """Compare two pytree node (type, context) pairs for equivalence.

    A container and its FX immutable variant count as the same type: any
    two ``dict``/``immutable_dict`` subclasses, or any two
    ``list``/``immutable_list`` subclasses, are equivalent when their
    contexts are equal. Every other pair (including when either type is
    ``None``) requires identical types and equal contexts.
    """
    if spec1_type is not None and spec2_type is not None:
        # Checked in order: mapping family first, then sequence family.
        for family in ((dict, immutable_dict), (list, immutable_list)):
            if issubclass(spec1_type, family) and issubclass(spec2_type, family):
                return spec1_context == spec2_context
    # Fallback: strict type identity; context compared only if types match.
    return spec1_type is spec2_type and spec1_context == spec2_context
[docs]class ExportedProgram: """ :func:`export` 导出的程序包。它包含 一个表示张量计算的 :class:`torch.fx.Graph`,一个包含所有提升参数和缓冲区的张量值的 state_dict,以及各种元数据。 你可以像原始的可调用对象一样调用 ExportedProgram,使用与 :func:`export` 跟踪相同的调用约定。 要对图进行转换,请使用 ``.module`` 属性访问 一个 :class:`torch.fx.GraphModule`。然后你可以使用 `FX 转换 `_ 重写图。之后,你可以简单地使用 :func:`export` 再次构造一个正确的 ExportedProgram。 """ def __init__( self, root: Union[torch.nn.Module, Dict[str, Any]], graph:</span