Shortcuts

torch.fx.proxy 的源代码

```html
# mypy: ignore-errors

import enum
import dis
import copy
import sys
import torch
import inspect
import operator
import traceback
import collections

from dataclasses import is_dataclass, fields


from .graph import magic_methods, reflectable_magic_methods, Graph
from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
from .node import Target, Node, Argument, base_types, map_aggregate
from ._compatibility import compatibility
from .operator_schemas import check_for_mutable_operation
import torch.fx.traceback as fx_traceback

__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
           'Proxy', 'Attribute', 'ParameterProxy', 'Scope',
           'ScopeContextManager']


@compatibility(is_backward_compatible=False)
class Scope:
    """ 记录模块路径和模块类型的Scope对象。
    Scope用于跟踪GraphModule中包含Node的模块的信息。例如::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # 这将在GraphModule中成为一个call_method Node,
                # 其scope为(module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self):
                self.sub = Sub()

            def forward(self, x):
                # 这同样是一个call_method Node,
                # 其scope为(module_path="", None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type


@compatibility(is_backward_compatible=False)
class ScopeContextManager:
    """ 用于跟踪符号追踪期间Node的Scope的上下文管理器。
    当进入Module的forward函数时,我们将更新当前模块的scope信息,
    当我们退出时,我们将恢复之前的scope信息。
    """

    def __init__(
        self,
        scope: Scope,
        current_scope: Scope,
    ):
        super().__init__()
        # 保留prev scope的副本以便在退出时恢复
        self._prev_scope = copy.copy(scope)
        # 更新scope为当前scope
        scope.module_path = current_scope.module_path
        scope.module_type = current_scope.module_type
        # 保存引用以便我们可以恢复它
        self._scope = scope

    def __enter__(self):
        return self._scope

    def __exit__(self, *args):
        self._scope.module_path = self._prev_scope.module_path
        self._scope.module_type = self._prev_scope.module_type
        return


_COPY_META_FIELDS = ["nn_module_stack", "source_fn_stack", "original_aten", "recompute", "from_node", "quantization_tag"]


@compatibility(is_backward_compatible=True)
class TracerBase:
    graph: Graph
    record_stack_traces : bool = False
    # 可变模式检查的功能标志
    # 默认在1.12中启用
    check_mutable_operations : bool = False
    # 断言追踪的功能标志
    trace_asserts : bool = False
    # 代理访问缓冲区值的功能标志
    proxy_buffer_attributes : bool = False

    # 要追踪的函数的名称。当``root``是``nn.Module``的实例时,它将仅被使用
    traced_func_name: str = "forward"

    # 将包含模块的名称映射到操作符名称
    scope : Scope

    # 记录模块调用栈
    module_stack: OrderedDict[str, Tuple[str, Any]]

    # 节点名称到模块scope的映射
    node_name_to_scope: Dict[str, Tuple[str, type]]

    @compatibility(is_backward_compatible=True)
    def create_node(self, kind : str, target : Target,
                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
优云智算