Shortcuts

torch.fx.interpreter 的源代码

from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .graph import Graph
from .node import Argument, Node, Target, map_arg, map_aggregate
from .proxy import Proxy
from ._symbolic_trace import Tracer
from ._compatibility import compatibility
from . import config
import torch.fx.traceback as fx_traceback
import torch
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import inspect
from contextlib import contextmanager
from torch.hub import tqdm

__all__ = ['Interpreter', 'Transformer']

[docs]@compatibility(is_backward_compatible=True) class Interpreter: """ 一个解释器按节点执行 FX 图。这种模式可以用于许多事情,包括编写代码转换以及分析传递。 解释器类中的方法可以被重写以自定义执行行为。可重写方法的调用层次结构如下: run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output() 示例: 假设我们想将所有 ``torch.neg`` 的实例与 ``torch.sigmoid`` 互换(包括它们的 ``Tensor`` 方法等价物)。我们可以这样子类化 Interpreter: class NegSigmSwapInterpreter(Interpreter): def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == 'neg': call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid()) 参数: module (torch.nn.Module): 要执行的模块 garbage_collect_values (bool): 是否在模块执行后删除值。这确保了执行期间的内存使用最优。可以禁用此功能以检查执行中的所有中间值,例如通过查看 ``Interpreter.env`` 属性。 graph (Optional[Graph]): 如果传递了此参数,解释器将执行此图而不是 `module.graph`,使用提供的 `module` 参数来满足任何状态请求。 """ @compatibility(is_backward_compatible=True) def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None): self.module = module self.submodules = dict(self.module.named_modules()) if graph is not None: self.graph = graph else: self.graph = self.module.graph self.env : Dict[Node, Any] = {} self.name = "Interpreter" self.garbage_collect_values = garbage_collect_values self.extra_traceback = True if self.garbage_collect_values: # 遍历反向节点并记录给定节点的首次使用实例。这表示程序执行顺序中节点的最后一次使用,我们将使用它来释放未使用的值 node_to_last_use : Dict[Node, Node] = {} self.user_to_last_uses : Dict[Node, List[Node]] = {} def register_last_uses(n : Node, user : Node): if n not in node_to_last_use: node_to_last_use[n] = user self.user_to_last_uses.setdefault(user, []).append(n) for node in reversed(self.graph.nodes): map_arg(node.args, lambda n: register_last_uses(n, node)) map_arg(node.kwargs, lambda n: register_last_uses(n, node))
[docs] @compatibility(is_backward_compatible=True) def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -&
优云智算