在ATen IR上编写图形转换¶
传递¶
由于 ATen IR 位于 FX Graph/GraphModule 级别,任何为 FX Graphs 编写的转换都可以轻松应用于 ATen IR。如果你熟悉编写 FX 图转换,那么这将是一样的。
编写转换的最直接方式是通过遍历给定的图并直接操作图中的节点。
例如,假设我们想要将
torch.ops.aten.add.Tensor() 调用替换为
torch.ops.aten.mul.Tensor() 调用:
import torch
def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
node.target = torch.ops.aten.mul.Tensor
我们也可以通过FX实用函数来删除和追加新节点,这些函数可以在
Graph
文档中找到。例如,如果我们想在add调用之后插入一个
torch.ops.aten.relu.default():
import torch
def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
# 指定插入点。在此范围内添加到图中的任何节点都将插入到 `node` 之后
with gm.graph.inserting_after(node):
# 插入一个新的 `call_function` 节点,操作符为 `torch.ops.aten.relu.default`
new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
# 将所有使用 `node` 的地方替换为使用 `new_relu_node`
node.replace_all_uses_with(new_relu_node)
一般来说,变换可以大致分为几个轴:
轴 A: 1. 创建一对一映射(例如分解) 2. 创建多对一映射(例如融合)
轴 B: 1. 进行正向迭代(例如形状传播)2. 进行反向迭代(例如死代码消除)
轴 C:1. 依赖于本地节点信息(例如,外变量转换)2. 依赖于全局图信息(例如,内存规划)
我们对这些用例频率的预测是:1. A.1, B.1, C.1 2. A.2 3. B.2, C.2
尽管我们可以通过直接操作图来实现所有的图转换,但我们还提供了一些辅助工具,以便于第1级和第2级用例的使用。
变压器¶
对于第1级用例(创建一对多映射、进行正向迭代以及查看本地节点信息),我们可以利用 Transformer 类来执行每个节点并重新创建一个图,除了指定的转换之外。
一对一传递¶
一对一映射的一个示例是,如果我们想要将操作 A 替换为另一个操作 B,我们可以运行 GraphModule,并且每次看到操作 A 时,返回操作 B。
一个例子是:
class ReplaceAddWithMul(torch.fx.Transformer):
def call_function(self, target, args, kwargs):
if target != torch.ops.aten.add.Tensor:
return super().call_function(target, args, kwargs)
return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)
transformed_graph_module = ReplaceAddWithMul(graph_module).transform()
调用 super().call_function(target, args, kwargs, meta) 会创建一个
call_function FX 节点,并返回使用给定参数运行操作符的结果。
一对一传递¶
如果我们想要进行一对多的映射,比如将操作A替换为另外两个操作B和C,我们将调用两次super().call_function来创建两个FX节点,一个使用操作B,另一个使用操作C,并返回运行操作C的结果。
例如:
class ReplaceAddWithMulSub(torch.fx.Transformer):
"""
原始:
def f(x, y):
return x + y
传递后:
def f(x, y):
z = x * y
return z - y
"""
def call_function(self, target, args, kwargs):
if target != torch.ops.aten.add.Tensor:
return super().call_function(target, args, kwargs)
x, y = args
mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})
transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()
一对零传递¶
如果我们想要移除一个操作,我们可以直接返回传入函数的值:
class RemoveDetachPass(torch.fx.Transformer):
def call_function(self, target, args, kwargs):
if target not in (
torch.ops.aten.detach.default,
torch.ops.aten.detach_copy.default,
):
return super().call_function(target, args, kwargs, meta)
assert len(args) == 1
return args[0]
transformed_graph_module = RemoveDetachPass(graph_module).transform()
利用本地信息¶
利用局部节点信息的一个例子是,如果我们想将图中的所有标量转换为张量,我们可以运行给定的 fx.GraphModule,并且对于每个包含标量的参数,我们将其转换为张量。它可能看起来像这样:
def args_map(target, fn, args, kwargs):
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
args = list(args)
kwargs = kwargs.copy()
# 根据传入的函数更新参数
def update(key, args, schema):
args[key] = fn(args[key], schema)
# 更新模式中的每个参数
for i, schema in enumerate(target._schema.arguments):
if schema.name in kwargs:
update(schema.name, kwargs, schema)
elif not schema.kwarg_only and i < len(args):
update(i, args, schema)
return tuple(args), kwargs
class ScalarToTensorPass(torch.fx.Transformer):
def call_function(self, target, args, kwargs):
breakpoint()
def try_coerce(value, arg):
return (
torch.tensor(value)
if isinstance(value, (float, int, bool))
and type(arg.type) == torch.TensorType
else value
)
args, kwargs = args_map(target, try_coerce, args, kwargs)
return super().call_function(target, args, kwargs)
transformed_graph_module = ScalarToTensorPass(graph_module).transform()
子图重写器¶
为了创建多对一的映射,我们可以利用FX的子图重写器。
给定一个模式,它会创建一个与该模式匹配的运算符子图,然后使用替换替换每个匹配的子图。
注意:
这是一个原地操作。
The pattern 和 replacement 输入必须是可调用的函数或包含与图中所用操作符相同的GraphModules(ATen操作符),以便子图重写器能够在图中找到正确的模式。模式/替换可调用对象的输入在匹配时将被视为通配符。
一个例子:
from torch.fx import subgraph_rewriter
def replace_patterns(graph_module):
def pattern(x, y):
x = torch.ops.aten.add.Tensor(x, y)
x = torch.ops.aten.mul.Tensor(x, y)
return x
def replacement(x, y):
return torch.ops.aten.sub.Tensor(x, y)
replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
traced_module, pattern, replacement
)
子图重写器返回一个 ReplacedPatterns 列表:
@dataclass
class ReplacedPatterns:
# 匹配找到的节点
anchor: Node
# 将模式子图中的节点映射到较大图中的节点
nodes_map: Dict[Node, Node]
# 添加到图中的节点列表
replacements: List[Node]
注意:
子图重写器创建的节点将不会具有在匹配节点中填充的元数据,但您可以使用
`ReplacedPatterns.nodes_map` 在原始图中找到匹配的节点,并使用 `ReplacedPatterns.replacements` 找到在转换图中被替换的节点。
传递管理器¶
`PassManager <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/pass_manager.py>`__
是一个用于在给定图模块上运行多个传递的类。在初始化一个
PassManager 实例时,我们传入一个我们想要运行的传递列表,并设置几个标志。为了在图模块上运行这些传递,我们可以直接将图模块传递给
PassManager 实例。
一个例子:
from torch.fx.passes.infra.pass_manager import PassManager
pm = PassManager(
passes=[replace_add_with_div, replace_div_with_mul],
run_checks_after_each_pass=True,
suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
要添加一组在每次传递后运行的通用检查,我们可以调用函数 set_checks(check: Callable),该函数接受一个可调用函数作为输入。如果设置了 run_checks_after_each_pass 标志,则在图模块的每次传递运行后将调用 check。
一个例子:
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])
def check_div_target(graph_module):
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target != torch.div:
raise ValueError("目标应该是除法!")
pm.add_checks(check_div_target)
pm(graph_module) # 在replace_div_with_mul pass之后引发ValueError
分区器¶
有几种常见的基于FX图的分区器,我们可以用来对图进行分区。
子图匹配器¶
为了在一个图中找到匹配特定模式的子图,我们可以利用 FX 的
`SubgraphMatcher <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/matcher_utils.py>`__。
类属性:
pattern (Graph): 目标匹配模式。图中的占位符节点在匹配时将被视为通配符。match_output (bool): 如果为True,模式图中的输出节点将被视为目标模式的一部分。如果为False,输出节点在匹配过程中将被忽略。match_placeholder (bool): 如果为True,模式图中的占位符节点将被视为目标模式的一部分。如果为False,占位符节点将被用作通配符。remove_overlapping_matches (bool): 如果为True,在出现重叠匹配的情况下,只会返回第一个匹配项。ignore_literals (bool): 如果为True,将不会检查字面量是否相等,而是将它们视为通配符。
一个例子:
The match 函数返回一个 InternalMatch 列表:
@dataclass
class InternalMatch():
# 找到匹配的节点
anchors: List[Node]
# 将模式子图中的节点映射到较大图中的节点
nodes_map: Dict[Node, Node] = field(default_factory=dict)
# 目标图中匹配模式中占位符的节点
placeholder_nodes: List[Node] = field(default_factory=list)
# 匹配子图中由输出返回的节点
returning_nodes: List[Node] = field(default_factory=list)
基于能力的分区器¶
要找到支持特定不变量的最大子图,我们可以利用 FX 的
`CapabilityBasedPartitioner <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/partitioner.py#L34>`__。
类属性
graph_module (torch.fx.GraphModule): 我们正在分区的图模块。operator_support (OperatorSupportBase): 用于确定图中的节点是否在分区中受支持的对象。allows_single_node_partition (bool): 如果为True,允许形成单节点分区。non_compute_ops (Optional[Sequence[str]]): 一组被视为“非计算”的操作(例如torch.ops.aten.view和_operator.getitem),以便分区器不会创建仅包含这些非计算操作的图allowed_single_node_partition_ops (Optional[Sequence[str]]): 一组允许出现在单节点分区中的操作。
The
`OperatorSupportBase <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#LL28C1-L28C1>`__
类被分区器用来确定图中的特定节点是否属于分区。这是通过重写
is_node_supported 函数来实现的。你可以通过使用
`chain <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L150>`__(如果任何 OperatorSupportBase 返回 False,则返回 False) 和
`any_chain <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L164>`__
(如果任何 OperatorSupportBase 返回 True,则返回 True) 来链接多个
OperatorSupportBase。
一个例子:
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
class AddMulOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
]
capability_partitioner = CapabilityBasedPartitioner(
graph_module,
op_support,
)
# 返回一个分区列表(每个分区包含的节点列表)
partition_list = capability_partitioner.propose_partitions()
# 将分区融合为图模块,并在图中插入`call_module`节点
fused_graph_module = capability_partitioner.fuse_partitions(partition_list)