torch.fx.interpreter 的源代码
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .graph import Graph
from .node import Argument, Node, Target, map_arg, map_aggregate
from .proxy import Proxy
from ._symbolic_trace import Tracer
from ._compatibility import compatibility
from . import config
import torch.fx.traceback as fx_traceback
import torch
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import inspect
from contextlib import contextmanager
from torch.hub import tqdm
__all__ = ['Interpreter', 'Transformer']
[docs]@compatibility(is_backward_compatible=True)
class Interpreter:
"""
一个解释器按节点执行 FX 图。这种模式可以用于许多事情,包括编写代码转换以及分析传递。
解释器类中的方法可以被重写以自定义执行行为。可重写方法的调用层次结构如下:
run()
+-- run_node
+-- placeholder()
+-- get_attr()
+-- call_function()
+-- call_method()
+-- call_module()
+-- output()
示例:
假设我们想将所有 ``torch.neg`` 的实例与 ``torch.sigmoid`` 互换(包括它们的 ``Tensor`` 方法等价物)。我们可以这样子类化 Interpreter:
class NegSigmSwapInterpreter(Interpreter):
def call_function(self, target : Target,
args : Tuple, kwargs : Dict) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(n)
def call_method(self, target : Target,
args : Tuple, kwargs : Dict) -> Any:
if target == 'neg':
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(n)
def fn(x):
return torch.sigmoid(x).neg()
gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
参数:
module (torch.nn.Module): 要执行的模块
garbage_collect_values (bool): 是否在模块执行后删除值。这确保了执行期间的内存使用最优。可以禁用此功能以检查执行中的所有中间值,例如通过查看 ``Interpreter.env`` 属性。
graph (Optional[Graph]): 如果传递了此参数,解释器将执行此图而不是 `module.graph`,使用提供的 `module` 参数来满足任何状态请求。
"""
@compatibility(is_backward_compatible=True)
def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None):
self.module = module
self.submodules = dict(self.module.named_modules())
if graph is not None:
self.graph = graph
else:
self.graph = self.module.graph
self.env : Dict[Node, Any] = {}
self.name = "Interpreter"
self.garbage_collect_values = garbage_collect_values
self.extra_traceback = True
if self.garbage_collect_values:
# 遍历反向节点并记录给定节点的首次使用实例。这表示程序执行顺序中节点的最后一次使用,我们将使用它来释放未使用的值
node_to_last_use : Dict[Node, Node] = {}
self.user_to_last_uses : Dict[Node, List[Node]] = {}
def register_last_uses(n : Node, user : Node):
if n not in node_to_last_use:
node_to_last_use[n] = user
self.user_to_last_uses.setdefault(user, []).append(n)
for node in reversed(self.graph.nodes):
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
[docs] @compatibility(is_backward_compatible=True)
def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -&