torch.fx.subgraph_rewriter 的源代码
```html
from .graph_module import GraphModule from .graph import Graph from .node import Node from ._symbolic_trace import symbolic_trace from ._compatibility import compatibility import copy from dataclasses import dataclass from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING import torch if TYPE_CHECKING: from .passes.utils.matcher_with_name_node_map_utils import InternalMatch __all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"] @compatibility(is_backward_compatible=True) class Match(NamedTuple): # 从哪个节点找到的匹配 anchor: Node # 将模式子图中的节点映射到较大图中的节点 nodes_map: Dict[Node, Node] @compatibility(is_backward_compatible=False) @dataclass class ReplacedPatterns: # 从哪个节点找到的匹配 anchor: Node # 将模式子图中的节点映射到较大图中的节点 nodes_map: Dict[Node, Node] # 添加到图中的节点列表 replacements: List[Node] def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: gm.delete_all_unused_submodules() if isinstance(replacement, GraphModule): replacement.graph.lint() def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]: module_path, _, attr_name = target.rpartition(".") try: mod: torch.nn.Module = gm.get_submodule(module_path) except AttributeError: return None attr = getattr(mod, attr_name, None) return attr for node in gm.graph.nodes: if node.op == "call_module" or node.op == "get_attr": gm_attr = try_get_attr(gm, node.target) replacement_attr = try_get_attr(replacement, node.target) # 情况1:此目标已作为属性存在于我们的结果GraphModule中。无论它是否存在于`replacement`中,现有的子模块都具有优先权。 if gm_attr is not None: continue # 情况2:目标仅作为属性存在于`replacement`中,因此我们需要将其复制过来。 elif replacement_attr is not None: new_attr = copy.deepcopy(replacement_attr) if isinstance(replacement_attr, torch.nn.Module): gm.add_submodule(node.target, new_attr) else: setattr(gm, node.target, new_attr) # 情况3:目标既不存在于`gm`中,也不存在于`replacement`中 else: raise RuntimeError("尝试在子图重写期间创建一个", node.op, "节点,但引用的属性在替换的GraphModule中不存在") gm.graph.lint()[docs]@compatibility(is_backward_compatible=True) def replace_pattern( gm: GraphModule, pattern: Union[Callable, GraphModule], replacement: Union[Callable, GraphModule] ) -> List[Match]: """ 在GraphModule的Graph中匹配所有可能的非重叠操作集及其数据依赖关系(``pattern``),然后将其替换为另一个子图(``replacement``)。 参数: ``gm``: 包装要操作的Graph的GraphModule ``pattern``: 要在``gm``中匹配以进行替换的子图 ``replacement``: 用于替换``pattern``的子图 返回: List[Match]: 一个``Match``对象列表,表示``pattern``在原始图中的匹配位置。如果没有匹配,列表为空。``Match``定义为: .. code-block:: python class Match(NamedTuple): # 从哪个节点找到的匹配 anchor: Node # 将模式子图中的节点映射到较大图中的节点 nodes_map: Dict[Node, Node] 示例: