Shortcuts

torch.fx.subgraph_rewriter 的源代码

```html
from .graph_module import GraphModule
from .graph import Graph
from .node import Node
from ._symbolic_trace import symbolic_trace
from ._compatibility import compatibility

import copy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING
import torch

if TYPE_CHECKING:
    from .passes.utils.matcher_with_name_node_map_utils import InternalMatch

__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]

@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    # 从哪个节点找到的匹配
    anchor: Node
    # 将模式子图中的节点映射到较大图中的节点
    nodes_map: Dict[Node, Node]

@compatibility(is_backward_compatible=False)
@dataclass
class ReplacedPatterns:
    # 从哪个节点找到的匹配
    anchor: Node
    # 将模式子图中的节点映射到较大图中的节点
    nodes_map: Dict[Node, Node]
    # 添加到图中的节点列表
    replacements: List[Node]

def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]:
        module_path, _, attr_name = target.rpartition(".")
        try:
            mod: torch.nn.Module = gm.get_submodule(module_path)
        except AttributeError:
            return None
        attr = getattr(mod, attr_name, None)
        return attr

    for node in gm.graph.nodes:
        if node.op == "call_module" or node.op == "get_attr":

            gm_attr = try_get_attr(gm, node.target)
            replacement_attr = try_get_attr(replacement, node.target)

            # 情况1:此目标已作为属性存在于我们的结果GraphModule中。无论它是否存在于`replacement`中,现有的子模块都具有优先权。
            if gm_attr is not None:
                continue

            # 情况2:目标仅作为属性存在于`replacement`中,因此我们需要将其复制过来。
            elif replacement_attr is not None:
                new_attr = copy.deepcopy(replacement_attr)
                if isinstance(replacement_attr, torch.nn.Module):
                    gm.add_submodule(node.target, new_attr)
                else:
                    setattr(gm, node.target, new_attr)

            # 情况3:目标既不存在于`gm`中,也不存在于`replacement`中
            else:
                raise RuntimeError("尝试在子图重写期间创建一个", node.op,
                                   "节点,但引用的属性在替换的GraphModule中不存在")

    gm.graph.lint()


[docs]@compatibility(is_backward_compatible=True) def replace_pattern( gm: GraphModule, pattern: Union[Callable, GraphModule], replacement: Union[Callable, GraphModule] ) -> List[Match]: """ 在GraphModule的Graph中匹配所有可能的非重叠操作集及其数据依赖关系(``pattern``),然后将其替换为另一个子图(``replacement``)。 参数: ``gm``: 包装要操作的Graph的GraphModule ``pattern``: 要在``gm``中匹配以进行替换的子图 ``replacement``: 用于替换``pattern``的子图 返回: List[Match]: 一个``Match``对象列表,表示``pattern``在原始图中的匹配位置。如果没有匹配,列表为空。``Match``定义为: .. code-block:: python class Match(NamedTuple): # 从哪个节点找到的匹配 anchor: Node # 将模式子图中的节点映射到较大图中的节点 nodes_map: Dict[Node, Node] 示例:
优云智算