torch.fx¶
概述¶
FX 是一个供开发者使用的工具包,用于转换 nn.Module
实例。FX 由三个主要组件组成:一个 符号追踪器,
一个 中间表示,以及 Python 代码生成。以下是这些组件的实际演示:
import torch
# 简单的模块示例
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# 符号追踪前端 - 捕获模块的语义
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
# 高级中间表示(IR)- 图表示
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# 代码生成 - 有效的Python代码
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
符号追踪器执行 Python 代码的“符号执行”。它通过代码传递称为代理的假值。这些代理的操作会被记录下来。有关符号追踪的更多信息,可以在 symbolic_trace() 和 Tracer 文档中找到。
中间表示(intermediate representation)是用于存储在符号追踪过程中记录的操作的容器。它由一系列节点组成,这些节点表示函数输入、调用点(对函数、方法或torch.nn.Module实例的调用)以及返回值。有关IR的更多信息,可以在Graph的文档中找到。IR是应用转换的格式。
Python代码生成是使FX成为Python到Python(或模块到模块)转换工具包的原因。对于每个Graph IR,我们可以创建与Graph语义匹配的有效Python代码。此功能被封装在GraphModule中,它是一个torch.nn.Module实例,持有Graph以及从Graph生成的forward方法。
总的来说,这一系列组件(符号追踪 -> 中间表示 -> 转换 -> Python代码生成)构成了FX的Python到Python转换管道。此外,这些组件可以单独使用。例如,符号追踪可以单独用于捕获代码的形式以进行分析(而非转换)目的。代码生成可以用于以编程方式生成模型,例如从配置文件生成。FX有很多用途!
可以在 示例 仓库中找到几个示例转换。
编写转换¶
什么是FX变换?本质上,它是一个看起来像这样的函数。
import torch
import torch.fx
def transform(m: nn.Module,
tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
# 步骤1:获取表示 `m` 中代码的图
# 注意:torch.fx.symbolic_trace 是调用 fx.Tracer.trace 并构造 GraphModule 的包装器。我们将在我们的变换中将其拆分,以允许调用者自定义跟踪行为。
graph : torch.fx.Graph = tracer_class().trace(m)
# 步骤2:修改此图或创建一个新图
graph = ...
# 步骤3:构造一个要返回的模块
return torch.fx.GraphModule(m, graph)
您的转换将接收一个torch.nn.Module,从中获取一个Graph,进行一些修改,并返回一个新的torch.nn.Module。您应该将您的FX转换返回的torch.nn.Module视为与常规torch.nn.Module相同——您可以将其传递给另一个FX转换,可以将其传递给TorchScript,或者可以运行它。确保您的FX转换的输入和输出是一个torch.nn.Module将允许组合性。
注意
也可以修改现有的 GraphModule,而不是创建一个新的,如下所示:
import torch
import torch.fx
def transform(m : nn.Module) -> nn.Module:
gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
# 修改 gm.graph
# <...>
# 从其 Graph 重新编译 `gm` 的 forward() 方法
gm.recompile()
return gm
请注意,您必须调用 GraphModule.recompile() 以使生成的
forward() 方法与修改后的 Graph 同步。
假设你已经传入了一个已经被追踪到Graph的torch.nn.Module,现在有两种主要的方法可以用来构建一个新的Graph。
图的快速入门¶
图的语义的完整处理可以在Graph文档中找到,但我们在这里将介绍基础知识。一个Graph是一个表示GraphModule上的方法的数据结构。这需要的信息是:
方法的输入是什么?
方法内部运行了哪些操作?
方法的输出(即返回)值是什么?
这三个概念都用 Node 实例表示。
让我们通过一个简短的例子来看看这意味着什么:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(
self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
在这里,我们定义了一个模块 MyModule 用于演示目的,实例化它,
符号化地追踪它,然后调用 Graph.print_tabular() 方法来打印
出一个表格,显示这个 Graph 的节点:
操作码
名称
目标
参数
关键字参数
占位符
x
x
()
{}
获取属性
线性权重
线性.权重
()
{}
调用函数
add_1
<内置函数 add>
(x, linear_weight)
{}
调用模块
线性_1
线性
(add_1,)
{}
调用方法
relu_1
ReLU
(linear_1,)
{}
调用函数
sum_1
<内置方法 sum …>
(relu_1,)
{‘dim’: -1}
调用函数
topk_1
<内置方法 topk …>
(sum_1, 3)
{}
输出
输出
输出
(topk_1,)
{}
我们可以使用这些信息来回答我们上面提出的问题。
方法的输入是什么?在FX中,方法输入是通过特殊的
placeholder节点指定的。在这种情况下,我们有一个placeholder节点,其target为x,这意味着我们有一个名为x的单一(非self)参数。方法中的操作是什么?
get_attr、call_function、call_module和call_method节点 表示方法中的操作。所有这些的语义的完整处理可以在Node文档中找到。方法的返回值是什么?在
Graph中,返回值由一个特殊的output节点指定。
鉴于我们现在了解了代码在FX中的基本表示方式,我们可以探讨如何编辑一个Graph。
图操作¶
直接图操作¶
构建这个新的Graph的一种方法是直接操作旧的图。为了帮助实现这一点,我们可以简单地获取从符号追踪中得到的Graph并对其进行修改。例如,假设我们希望将torch.add()调用替换为torch.mul()调用。
import torch
import torch.fx
# 示例模块
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(m)
# FX 将其 Graph 表示为一个有序的节点列表,
# 因此我们可以遍历它们。
for node in graph.nodes:
# 检查我们是否在调用一个函数(例如:
# torch.add)
if node.op == 'call_function':
# target 属性是 call_function 调用的函数。
if node.target == torch.add:
node.target = torch.mul
graph.lint() # 进行一些检查以确保
# Graph 是良构的。
return fx.GraphModule(m, graph)
我们还可以进行更复杂的Graph重写,例如删除或追加节点。为了帮助这些转换,FX提供了一些用于转换图的实用函数,可以在Graph文档中找到。下面是一个使用这些API追加torch.relu()调用的示例。
# 指定插入点。在此范围内添加到
# Graph中的任何节点都将插入到`node`之后
with traced.graph.inserting_after(node):
# 插入一个新的`call_function`节点,调用`torch.relu`
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# 我们希望所有使用`node`值的地方
# 现在使用我们在`relu`调用后添加的值。
# 我们使用`replace_all_uses_with` API来实现这一点。
node.replace_all_uses_with(new_node)
对于仅包含替换的简单变换,您也可以使用子图重写器。
使用 replace_pattern() 进行子图重写¶
FX 还提供了在直接图形操作之上的另一层自动化。
replace_pattern() API 本质上是一个用于编辑
Graph 的“查找/替换”工具。它允许你指定一个 pattern 和 replacement 函数
并且它会跟踪这些函数,找到 pattern 图形中操作组的实例,并用 replacement 图形的副本来替换这些实例。这可以帮助大大自动化繁琐的图形操作代码,当转换变得更加复杂时,这些代码可能会变得难以管理。
代理/重追溯¶
另一种操作 Graph 的方法是重用符号追踪中使用的 Proxy 机制。例如,假设我们想要编写一个将 PyTorch 函数分解为更小操作的转换。它会将每个 F.relu(x) 调用转换为 (x > 0) * x。一种可能性是执行必要的图重写,在 F.relu 之后插入比较和乘法,然后清理原始的 F.relu。然而,我们可以通过使用 Proxy 对象来自动记录操作到 Graph 中来自动化这个过程。
要使用此方法,我们将要插入的操作编写为常规的 PyTorch 代码,并使用 Proxy 对象作为参数调用该代码。
这些 Proxy 对象将捕获对其执行的操作,并将它们附加到 Graph 中。
# 注意,这个分解规则可以像普通的Python一样阅读
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
将`model`分解为更小的组成操作。
目前,这仅支持将ReLU分解为其
数学定义:(x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# 通过用代理包装参数,
# 我们可以分派到适当的
# 分解规则,并通过符号化跟踪将其隐式添加到图表中。
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# 对`Proxy`的操作总是产生新的`Proxy`,并且
# 我们的分解规则的返回值也不例外。
# 我们需要从`Proxy`中提取底层的`Node`
# 以便在后续的转换迭代中使用它。
new_node = output_proxy.node
env[node.name] = new_node
else:
# 默认情况:我们没有这个节点的分解规则,
# 所以只需将节点复制到新图中。
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免显式图操作外,使用Proxy还允许您将重写规则指定为原生Python代码。对于需要大量重写规则的转换(如vmap或grad),这通常可以提高规则的可读性和可维护性。请注意,虽然调用Proxy时我们也传递了一个指向底层变量graph的跟踪器。这样做是为了防止如果图中的操作是n元操作(例如,add是一个二元运算符),调用Proxy不会创建多个图跟踪器实例,这可能导致意外的运行时错误。我们建议使用这种方法,特别是当底层操作不能安全地假设为单一时。
解释器模式¶
在FX中,一个有用的代码组织模式是遍历一个Graph中的所有Node并执行它们。这可以用于多种用途,包括对流经图的值进行运行时分析或通过使用Proxy进行重跟踪来转换代码。例如,假设我们想要运行一个GraphModule并在运行时记录我们在节点上看到的torch.Tensor的形状和dtype属性。这可能看起来像:
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
形状传播。这个类接受一个 `GraphModule`。
然后,它的 `propagate` 方法使用给定的参数逐节点执行 `GraphModule`。
在每个操作执行时,ShapeProp 类会存储每个操作输出值的形状和
元素类型在操作的 `Node` 的 `shape` 和 `dtype` 属性上。
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
# 这是形状传播特有的代码。
# 你可以删除这个 `if` 分支,这样就变成了
# 一个通用的 GraphModule 解释器。
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
正如你所见,一个完整的FX解释器并不那么复杂,但它可以非常有用。为了简化使用这种模式,我们提供了Interpreter类,它包含了上述逻辑,并且可以通过方法重写来覆盖解释器执行的某些方面。
除了执行操作外,我们还可以通过将Proxy值传递给解释器来生成一个新的Graph。
同样,我们提供了Transformer类来涵盖这种模式。Transformer的行为类似于
Interpreter,但不是调用run方法从模块中获取具体的输出值,而是调用
Transformer.transform()方法来返回一个新的
GraphModule,该模块受到您安装的任何转换规则的约束,这些规则作为重写方法。
调试¶
介绍¶
在编写转换的过程中,我们的代码往往并不完全正确。 在这种情况下,我们可能需要进行一些调试。关键是从后往前进行:首先,检查调用生成模块的结果以验证或反驳正确性。然后,检查并调试生成的代码。最后,调试导致生成代码的转换过程。
如果您不熟悉调试器,请参阅辅助部分 可用的调试器。
转换创作中的常见陷阱¶
非确定性的
set迭代顺序。在 Python 中,set数据类型是无序的。使用set来包含像Node这样的对象集合,例如,可能会导致意外的非确定性。一个例子是迭代一组Node并将它们插入到Graph中。由于set数据类型是无序的,输出程序中的操作顺序将是不确定的,并且在程序的不同调用之间可能会发生变化。推荐的替代方法是使用dict数据类型,自 Python 3.7(以及 cPython 3.6)以来,它是 插入顺序。可以通过将需要去重的值存储在dict的键中来等效地使用dict。
检查模块的正确性¶
因为大多数深度学习模块的输出由浮点数
torch.Tensor 实例组成,检查两个
torch.nn.Module 的结果是否相等并不像进行简单的等式检查那样直接。为了说明这一点,让我们使用一个
例子:
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -> torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# 想象一下我们在这里做了一些转换
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: 张量包含多个值时的布尔值是模糊的
"""
在这里,我们尝试使用 == 相等运算符检查两个深度学习模型的值是否相等。然而,由于该运算符返回的是张量而不是布尔值,并且浮点值的比较应使用误差范围(或epsilon)来考虑浮点运算的非交换性(更多详情请参见 这里),因此这种方法定义不明确。我们可以使用 torch.allclose() 来代替,它将为我们提供一个近似比较,考虑相对和绝对容差阈值:
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
这是我们工具箱中的第一个工具,用于检查转换后的模块是否与我们期望的参考实现表现一致。
调试生成的代码¶
因为FX在GraphModule上生成了forward()函数,使用传统的调试技术,如print语句或pdb,并不那么直接。幸运的是,我们有几种技术可以用来调试生成的代码。
使用 pdb¶
调用 pdb 进入正在运行的程序。尽管表示 Graph 的代码不在任何源文件中,我们仍然可以在调用前向传播时手动使用 pdb 进入它。
import torch
import torch.fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph = tracer_class().trace(inp)
# 转换逻辑在这里
# <...>
# 返回新的模块
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# 当这一行在运行时执行时,我们将进入一个交互式的 `pdb` 提示符。我们可以使用 `step` 或 `s` 命令来
# 进入下一行的执行
import pdb; pdb.set_trace()
my_module_transformed(input_value)
打印生成的代码¶
如果您想多次运行相同的代码,那么使用pdb逐步跳转到正确的代码可能会有些繁琐。在这种情况下,一种方法是简单地将生成的forward传递复制粘贴到您的代码中,并从那里进行检查。
# 假设 `traced` 是一个已经经过一些转换的 GraphModule
# 复制这段代码以备后用
print(traced)
# 打印从符号追踪生成的代码。输出如下:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# 子类化原始的 Module
class SubclassM(M):
def __init__(self):
super().__init__()
# 粘贴生成的 `forward` 函数(就是我们打印并复制的那个)到这里
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# 创建原始未追踪 Module 的实例。然后,创建一个带有复制的 `forward` 函数的 Module 实例。
# 现在我们可以比较原始版本和追踪版本的输出。
pre_trace = M()
post_trace = SubclassM()
使用 to_folder 函数从 GraphModule¶
GraphModule.to_folder() 是 GraphModule 中的一个方法,允许你将生成的 FX 代码转储到一个文件夹中。虽然通常只需将前向传播代码复制到代码中即可,如 打印生成的代码 所示,但使用 to_folder 可能更容易检查模块和参数。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
运行上述示例后,我们可以查看
foo/module.py 中的代码并根据需要进行修改(例如添加 print
语句或使用 pdb)来调试生成的代码。
调试转换¶
既然我们已经确定了一个转换正在生成错误的代码,那么是时候调试这个转换本身了。首先,我们会检查文档中的符号追踪的局限性部分。一旦我们确认追踪是按预期工作的,目标就变成了找出在我们的GraphModule转换过程中出了什么问题。在编写转换中可能有一个快速的答案,但如果没有,有几种方法可以检查我们追踪的模块:
# 示例模块
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# 创建 `M` 的一个实例
m = M()
# 符号化地追踪 `M` 的一个实例(返回一个 GraphModule)。在这个例子中,我们只讨论如何检查一个
# GraphModule,所以我们没有展示任何示例转换,以保持简洁。
traced = symbolic_trace(m)
# 打印由追踪模块生成的代码。
print(traced)
# 生成的 `forward` 函数是:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# 打印内部图。
print(traced.graph)
# 这个打印输出返回:
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# 打印内部图的表格表示。
traced.graph.print_tabular()
# 这给了我们:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add (x, y) {}
output output output (add,) {}
"""
```使用上述实用函数,我们可以在应用转换之前和之后比较我们的跟踪模块。有时,简单的视觉比较足以追踪到一个错误。如果仍然不清楚哪里出了问题,像pdb这样的调试器可以是一个很好的下一步。
基于上述示例,考虑以下代码:
# 示例用户定义函数
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
# 从我们的跟踪模块中获取图
g = tracer_class().trace(module)
"""
对 `g` 的转换在这里进行
"""
return fx.GraphModule(module, g)
# 转换图
transformed = transform_graph(traced)
# 打印转换后的新代码。检查是否符合我们的预期
print(transformed)
使用上述示例,假设调用 print(traced)
显示我们的转换中存在错误。我们想使用调试器找出问题所在。我们启动一个 pdb 会话。我们可以通过在
transform_graph(traced) 处中断,然后按 s 来“步入”对
transform_graph(traced) 的调用,从而查看转换过程中发生了什么。
我们也可以通过编辑 print_tabular 方法来打印图中的节点的不同属性,从而获得好运。(例如,我们可能想要查看节点的 input_nodes 和 users。)
可用的调试器¶
最常见的 Python 调试器是
pdb。你可以通过在命令行中输入
python -m pdb FILENAME.py 以“调试模式”启动你的程序,其中 FILENAME
是你想要调试的文件名。之后,你可以使用
pdb 调试器命令
逐步移动你的运行程序。通常在启动 pdb 时设置一个
断点(b LINE-NUMBER),然后调用 c 运行程序直到该点。这可以避免你必须
逐行执行(使用 s 或 n)以到达你想要检查的代码部分。或者,你可以在你想要中断的行之前写入
import pdb; pdb.set_trace()。如果你添加了 pdb.set_trace(),你的程序将在运行时自动
进入调试模式。(换句话说,你可以直接在命令行中输入
python FILENAME.py 而不是
python -m pdb FILENAME.py。)一旦你在调试模式下运行你的文件,你可以逐步执行代码并使用某些命令检查你的程序的内部状态。网上有许多关于 pdb 的优秀教程,包括 RealPython 的
“使用 Pdb 进行 Python 调试”。
像 PyCharm 或 VSCode 这样的 IDE 通常内置了调试器。在你的 IDE 中,你可以选择 a) 通过在 IDE 中打开一个终端窗口(例如在 VSCode 中,选择 View → Terminal)使用 pdb,或者 b) 使用内置的调试器(通常是围绕 pdb 的图形界面)。
符号追踪的局限性¶
FX 使用一种 符号追踪(又名 符号执行)系统,
以捕获程序的语义并以可转换/可分析的形式表示。该系统是 追踪 的,因为它执行程序(实际上是一个
torch.nn.Module 或函数)以记录操作。它是
符号 的,因为在执行过程中流经程序的数据不是真实数据,而是符号(在 FX 术语中称为 Proxy)。
尽管符号追踪适用于大多数神经网络代码,但它也有一些局限性。
动态控制流¶
符号追踪的主要限制是它目前不支持动态控制流。也就是说,循环或if语句,其中条件可能依赖于程序的输入值。
例如,让我们检查以下程序:
def func_to_trace(x):
if x.sum() > 0:
return torch.relu(x)
else:
return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
<...>
文件 "dyn.py",第 6 行,在 func_to_trace 中
if x.sum() > 0:
文件 "pytorch/torch/fx/proxy.py",第 155 行,在 __bool__ 中
return self.tracer.to_bool(self)
文件 "pytorch/torch/fx/proxy.py",第 85 行,在 to_bool 中
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
if语句的条件依赖于x.sum()的值,而x.sum()的值又依赖于函数输入x的值。由于x可能会发生变化(例如,如果你将一个新的输入张量传递给跟踪函数),这就是动态控制流。回溯会通过你的代码向上追溯,以显示这种情况发生的位置。
静态控制流¶
另一方面,所谓的静态控制流是受支持的。静态控制流是指在调用过程中其值不会改变的循环或if语句。通常,在PyTorch程序中,这种控制流出现在根据超参数对模型架构进行决策的代码中。作为一个具体的例子:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# 这个if语句被称为静态控制流。
# 它的条件不依赖于任何输入值
if self.do_activation:
x = torch.relu(x)
return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
linear_1 = self.linear(x); x = None
return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
relu_1 = torch.relu(linear_1); linear_1 = None
return relu_1
"""
if-statement if self.do_activation 不依赖于任何函数输入,因此它是静态的。do_activation 可以被视为一个超参数,并且具有该参数不同值的 MyModule 的不同实例具有不同的代码。这是一种有效的模式,由符号追踪支持。
许多动态控制流的实例在语义上是静态控制流。这些实例可以通过移除对输入值的数据依赖来支持符号追踪,例如通过将值移动到Module属性中,或在符号追踪期间将具体值绑定到参数上:
def f(x, flag):
if flag: return x
else: return x*2
fx.symbolic_trace(f) # 失败!
fx.symbolic_trace(f, concrete_args={'flag': True})
在真正动态控制流的情况下,包含此代码的程序部分可以被跟踪为对方法(参见使用Tracer类自定义跟踪)或函数(参见wrap())的调用,而不是通过它们进行跟踪。
非torch函数¶
FX 使用 __torch_function__ 作为拦截调用的机制(有关此内容的更多信息,请参阅 技术概述)。一些函数,例如内置的 Python 函数或 math 模块中的函数,不受 __torch_function__ 的覆盖,但我们仍然希望在符号追踪中捕获它们。例如:
import torch
import torch.fx
from math import sqrt
def normalize(x):
"""
通过批次维度的大小对 `x` 进行归一化
"""
return x / sqrt(len(x))
# 这是有效的 Python 代码
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
<...>
文件 "sqrt.py",第 9 行,在 normalize 中
return x / sqrt(len(x))
文件 "pytorch/torch/fx/proxy.py",第 161 行,在 __len__ 中
引发 RuntimeError("'len' 在符号追踪中默认不支持。如果你想记录这个调用,请在模块作用域中调用 torch.fx.wrap('len')")
RuntimeError: 'len' 在符号追踪中默认不支持。如果你想记录这个调用,请在模块作用域中调用 torch.fx.wrap('len')
"""
错误告诉我们内置函数 len 不受支持。
我们可以使这样的函数在跟踪中作为直接调用记录,使用 wrap() API:
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
len_1 = len(x)
sqrt_1 = math.sqrt(len_1); len_1 = None
truediv = x / sqrt_1; x = sqrt_1 = None
return truediv
"""
使用Tracer类自定义跟踪¶
The Tracer 类是实现 symbolic_trace 的基础类。可以通过子类化 Tracer 来自定义跟踪行为,如下所示:
class MyCustomTracer(torch.fx.Tracer):
# 在这里你可以覆盖各种方法
# 以自定义追踪。请参阅 `Tracer` API
# 参考
pass
# 让我们使用这个自定义追踪器来追踪这个模块
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() 返回一个 Graph。让我们将其包装在一个
# GraphModule 中以使其可运行
traced = torch.fx.GraphModule(mod, traced_graph)
叶模块¶
Leaf Modules 是那些在符号跟踪中作为调用出现而不是被跟踪的模块。默认的 Leaf Modules 集合是标准的 torch.nn 模块实例集合。例如:
class MySpecialSubmodule(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
self.submod = MySpecialSubmodule()
def forward(self, x):
return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` 被保留为一个调用,然而 `submod` 被追踪了。
# 这是因为默认的“叶模块”集合包括了所有
# 标准的 `torch.nn` 模块。
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
neg_1 = torch.neg(linear_1); linear_1 = None
return neg_1
"""
可以通过重写Tracer.is_leaf_module()来自定义叶子模块的集合。
杂项¶
张量构造函数(例如
torch.zeros,torch.ones,torch.rand,torch.randn,torch.sparse_coo_tensor) 目前不可追踪。确定性构造函数(
zeros、ones)可以使用,并且它们生成的值将作为常量嵌入到跟踪中。这只有在这些构造函数的参数引用动态输入大小时才会出现问题。在这种情况下,ones_like或zeros_like可能是可行的替代方案。非确定性构造函数(
rand、randn)将在跟踪中嵌入一个随机值。这可能不是预期的行为。一种解决方法是将torch.randn包装在一个torch.fx.wrap函数中,并调用该函数。
@torch.fx.wrap def torch_randn(x, shape): return torch.randn(shape) def f(x): return x + torch_randn(x, 5) fx.symbolic_trace(f)
此行为可能会在未来的版本中修复。
类型注解
Python 3 风格的类型注解(例如
func(x : torch.Tensor, y : int) -> torch.Tensor)是支持的 并且将由符号追踪保留。Python 2风格的注释类型注解
# 类型: (torch.Tensor, int) -> torch.Tensor目前不支持。目前不支持在函数内对本地名称进行注解。
在
训练标志和子模块周围抓住了当使用像
torch.nn.functional.dropout这样的函数时,通常会将训练参数作为self.training传递进来。在 FX 追踪期间,这可能会被作为常量值固定下来。
import torch import torch.fx class DropoutRepro(torch.nn.Module): def forward(self, x): return torch.nn.functional.dropout(x, training=self.training) traced = torch.fx.symbolic_trace(DropoutRepro()) print(traced.code) """ def forward(self, x): dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None return dropout """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x) """ AssertionError: 张量不接近! 不匹配的元素: 15 / 15 (100.0%) 最大绝对差异: 1.6207983493804932 在索引 (0, 2) 处 (允许的最大差异为 1e-05) 最大相对差异: 1.0 在索引 (0, 0) 处 (允许的最大差异为 0.0001) """
然而,当使用标准的
nn.Dropout()子模块时,训练标志被封装,并且由于nn.Module对象模型的保留,可以进行更改。
class DropoutRepro2(torch.nn.Module): def __init__(self): super().__init__() self.drop = torch.nn.Dropout() def forward(self, x): return self.drop(x) traced = torch.fx.symbolic_trace(DropoutRepro2()) print(traced.code) """ def forward(self, x): drop = self.drop(x); x = None return drop """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x)
由于这种差异,请考虑将那些与
training标志动态交互的模块标记为叶子模块。
API参考¶
- torch.fx.symbolic_trace(root, concrete_args=None)[源代码]¶
符号追踪 API
给定一个
nn.Module或函数实例root,此函数将返回一个通过记录在追踪root过程中看到的操作而构建的GraphModule。concrete_args允许你部分特化你的函数,无论是为了移除控制流还是数据结构。例如:
def f(a, b): if b == True: return a else: return a*2
由于控制流的存在的,FX通常无法跟踪此过程。然而,我们可以使用concrete_args来专门处理b的值以进行跟踪:
f = fx.symbolic_trace(f, concrete_args={'b': False}) assert f(3, False) == 6
请注意,尽管您仍然可以传入不同的b值,但它们将被忽略。
我们也可以使用concrete_args来消除函数中的数据结构处理。这将使用pytrees来展平您的输入。为了避免过度专门化,对于不应该专门化的值,请传入fx.PH。例如:
def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) assert f({'a': 1, 'b': 2, 'c': 4}) == 7
- Parameters
根 (联合[torch.nn.Module, 可调用]) – 要追踪并转换为图表示的模块或函数。
concrete_args (可选[字典[str, 任意]]) – 要部分特化的输入
- Returns
从记录的操作创建的模块来自
root。- Return type
注意
此API的向后兼容性得到保证。
- torch.fx.wrap(fn_or_name)[源代码]¶
此函数可以在模块级范围内调用,以将 fn_or_name 注册为“叶函数”。 “叶函数”将在 FX 跟踪中保留为 CallFunction 节点,而不是被跟踪通过:
# foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap('my_custom_function') def fn_to_be_traced(x, y): # 在符号追踪时,对 my_custom_function 的调用将被插入到 # 图中,而不是对其进行追踪。 return my_custom_function(x, y)
这个函数也可以等价地用作装饰器:
# foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y
一个包装函数可以被视为一个“叶子函数”,类似于“叶子模块”的概念,也就是说,它们是在FX跟踪中保留为调用的函数,而不是被跟踪通过的函数。
- Parameters
fn_or_name (Union[str, Callable]) – 当调用时插入到图中的函数或全局函数的名称
注意
此API的向后兼容性得到保证。
- class torch.fx.GraphModule(*args, **kwargs)[源代码]¶
GraphModule 是一个由 fx.Graph 生成的 nn.Module。GraphModule 具有一个
graph属性,以及由该graph生成的code和forward属性。警告
当
graph被重新赋值时,code和forward将自动重新生成。然而,如果你编辑了graph的内容而没有重新赋值graph属性本身,你必须调用recompile()来更新生成的代码。注意
此API的向后兼容性得到保证。
- __init__(root, graph, class_name='GraphModule')[源代码]¶
构建一个GraphModule。
- Parameters
根 (联合[torch.nn.Module, 字典[字符串, 任意]) –
根可以是 nn.Module 实例或映射字符串到任意属性类型的字典。 如果根是一个模块,图的节点中任何对基于模块的对象的引用(通过限定名称)将在根的模块层次结构中从相应位置复制到 GraphModule 的模块层次结构中。 如果根是一个字典,节点中的限定名称将直接在字典的键中查找。字典映射到的对象将被复制到 GraphModule 的模块层次结构中的适当位置。图 (Graph) –
graph包含此GraphModule应使用的节点以进行代码生成类名 (str) –
name表示此 GraphModule 的名称,用于调试目的。如果未设置,所有错误消息将报告为源自GraphModule。设置此项可能有助于将其设置为root的原始名称或在您的转换上下文中具有意义的名称。
注意
此API的向后兼容性得到保证。
- add_submodule(target, m)[源代码]¶
将给定的子模块添加到
self。如果它们是
target的子路径,则在尚不存在的情况下安装空模块。- Parameters
- Returns
- 子模块是否可以被插入。对于
此方法返回True,链中的每个对象 由
target表示必须要么a) 还不存在, 或者b) 引用一个nn.Module(不是参数或其他属性)
- Return type
注意
此API的向后兼容性得到保证。
- delete_all_unused_submodules()[源代码]¶
删除
self中所有未使用的子模块。一个模块被认为是“已使用”的,如果以下任一条件为真: 1. 它有被使用的子模块 2. 它的forward通过
call_module节点直接调用 3. 它有一个非Module属性,该属性通过get_attr节点被使用此方法可以被调用来清理一个
nn.Module,而不需要手动调用delete_submodule在每个未使用的子模块上。注意
此API的向后兼容性得到保证。
- delete_submodule(target)[源代码]¶
从
self中删除给定的子模块。如果
target不是一个有效的目标,模块将不会被删除。- Parameters
目标 (str) – 新子模块的完全限定字符串名称 (参见
nn.Module.get_submodule中的示例,了解如何指定完全限定字符串。)- Returns
- 目标字符串是否引用了
我们想要删除的子模块。返回值为
False表示target不是对子模块的有效引用。
- Return type
注意
此API的向后兼容性得到保证。
- print_readable(print_output=True)[源代码]¶
返回为当前GraphModule及其子GraphModules生成的Python代码
警告
此API是实验性的,并且不向后兼容。
- class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[源代码]¶
Graph是 FX Intermediate Representation 中使用的主要数据结构。 它由一系列Node组成,每个节点代表调用点(或其他语法结构)。这些Node的列表共同构成一个有效的 Python 函数。例如,以下代码
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) m = MyModule() gm = torch.fx.symbolic_trace(m)
将生成以下图形:
print(gm.graph)
graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) return topk_1关于在
Graph中表示的操作的语义,请参见Node。注意
此API的向后兼容性得到保证。
- call_function(the_function, args=None, kwargs=None, type_expr=None)[源代码]¶
插入一个
call_function节点到图中。一个call_function节点 表示对一个 Python 可调用对象的调用,由the_function指定。- Parameters
the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 操作符、Python 函数,或者是
builtins或operator命名空间中的成员。args (可选[元组[参数, ...]]) – 要传递给被调用函数的定位参数。
kwargs (可选[字典[str, 参数]]) – 传递给被调用函数的键值参数
type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。
- Returns
新创建并插入的
call_function节点。- Return type
注意
此方法的插入点和类型表达式规则与
Graph.create_node()相同。注意
此API的向后兼容性得到保证。
- call_method(method_name, args=None, kwargs=None, type_expr=None)[源代码]¶
在
Graph中插入一个call_methodNode。一个call_method节点 表示对args的第0个元素上的给定方法的调用。- Parameters
- Returns
新创建并插入的
call_method节点。- Return type
注意
此方法的插入点和类型表达式规则与
Graph.create_node()相同。注意
此API的向后兼容性得到保证。
- call_module(module_name, args=None, kwargs=None, type_expr=None)[源代码]¶
在
Graph中插入一个call_moduleNode。一个call_module节点 表示对Module层次结构中某个Module的forward()函数的调用。- Parameters
module_name (str) – 在
Module层次结构中要调用的Module的限定名称。例如,如果被跟踪的Module有一个名为foo的子模块,该子模块又有一个名为bar的子模块,则应将限定名称foo.bar作为module_name传递以调用该模块。args (可选[元组[参数, ...]]) – 要传递给被调用方法的位置参数。请注意,这不应包括
self参数。kwargs (可选[字典[str, Argument]]) – 要传递给被调用方法的关键字参数
type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。
- Returns
新创建并插入的
call_module节点。- Return type
注意
此方法的插入点和类型表达式规则与
Graph.create_node()相同。注意
此API的向后兼容性得到保证。
- create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[源代码]¶
创建一个
节点并将其添加到当前插入点的图中。 请注意,当前插入点可以通过Graph.inserting_before()和Graph.inserting_after()来设置。- Parameters
op (str) – 此节点的操作码。可以是 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’ 之一。这些操作码的语义在
Graph文档字符串中描述。args (可选[元组[参数, ...]]) – 是传递给此节点的参数元组。
kwargs (可选[字典[str, Argument]]) – 此节点的kwargs
名称 (可选[字符串]) – 为
节点指定的一个可选字符串名称。 这将影响在生成的Python代码中分配的值的名称。type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。
- Returns
新创建并插入的节点。
- Return type
注意
此API的向后兼容性得到保证。
- eliminate_dead_code()[源代码]¶
根据每个节点的用户数量以及节点是否具有副作用,从图中移除所有无效代码。在调用之前,图必须进行拓扑排序。
- Returns
图表是否由于传递而发生了变化。
- Return type
示例:
在死代码被消除之前,a 来自 a = x + 1 下面没有使用者,因此可以从图中消除而不会产生影响。
def forward(self, x): a = x + 1 return x + self.attr_1
在死代码被消除后,a = x + 1 已被移除,而其余的 forward 仍然保留。
def forward(self, x): return x + self.attr_1
警告
死代码消除有一些启发式方法来避免移除具有副作用的节点(参见 Node.is_impure),但总体覆盖率非常差,因此除非你确定你的 FX 图完全由函数操作组成,否则你应该假设调用此方法是不可靠的。
注意
此API的向后兼容性得到保证。
- erase_node(to_erase)[源代码]¶
从
Graph中删除一个Node。如果在Graph中仍有该节点的用户,则会抛出异常。- Parameters
to_erase (节点) – 要从
Graph中删除的Node。
注意
此API的向后兼容性得到保证。
- get_attr(qualified_name, type_expr=None)[源代码]¶
在图中插入一个
get_attr节点。一个get_attrNode表示从Module层次结构中获取一个属性。- Parameters
qualified_name (str) – 要检索的属性的完全限定名称。 例如,如果被跟踪的模块有一个名为
foo的子模块,该子模块有一个名为bar的子模块,该子模块有一个名为baz的属性,则应将完全限定名称foo.bar.baz作为qualified_name传递。type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。
- Returns
新创建并插入的
get_attr节点。- Return type
注意
此方法的插入点和类型表达式规则与
Graph.create_node相同。注意
此API的向后兼容性得到保证。
- graph_copy(g, val_map, return_output_node=False)[源代码]¶
将给定图中的所有节点复制到
self中。- Parameters
- Returns
如果
g有一个output节点,则self中的值现在等同于g中的输出值。否则为None。- Return type
可选[联合[元组[任意, …], 列表[任意], 字典[字符串, 任意], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载]]
注意
此API的向后兼容性得到保证。
- inserting_after(n=None)[源代码]¶
- Set the point at which create_node and companion methods will insert into the graph.
当在‘with’语句中使用时,这将临时设置插入点,并在with语句退出时恢复它:
with g.inserting_after(n): ... # 在节点n之后插入 ... # 插入点恢复到之前的状态 g.inserting_after(n) # 永久设置插入点
参数:
- n (Optional[Node]): The node before which to insert. If None this will insert after
整个图的开始。
- Returns:
一个资源管理器,将在
__exit__时恢复插入点。
注意
此API的向后兼容性得到保证。
- inserting_before(n=None)[源代码]¶
- Set the point at which create_node and companion methods will insert into the graph.
当在‘with’语句中使用时,这将临时设置插入点,并在with语句退出时恢复它:
with g.inserting_before(n): ... # 在节点 n 之前插入 ... # 插入点恢复到之前的状态 g.inserting_before(n) # 永久设置插入点
参数:
- n (Optional[Node]): The node before which to insert. If None this will insert before
整个图的开始。
- Returns:
一个资源管理器,将在
__exit__时恢复插入点。
注意
此API的向后兼容性得到保证。
- lint()[源代码]¶
对当前图进行各种检查,以确保其结构良好。特别是: - 检查节点是否具有正确的所有权(归属于此图) - 检查节点是否按拓扑顺序排列 - 如果此图归属于一个GraphModule,检查目标是否存在于该GraphModule中
注意
此API的向后兼容性得到保证。
- node_copy(node, arg_transform=<function Graph.<lambda>>)[源代码]¶
将一个节点从一个图复制到另一个图中。
arg_transform需要将参数从节点所在的图转换到当前图。示例:# 将 `g` 中的所有节点复制到 `new_graph` 中 g : torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
- Parameters
- Return type
注意
此API的向后兼容性得到保证。
- property nodes: _node_list¶
获取构成此图的节点列表。
请注意,这个
节点列表表示是一个双向链表。在迭代期间进行变异(例如删除一个节点,添加一个节点)是安全的。- Returns
一个双向链表的节点。请注意,可以在该列表上调用
reversed来切换迭代顺序。
- on_generate_code(make_transformer)[源代码]¶
在生成Python代码时注册一个转换器函数
- Args:
- make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):
一个返回代码转换器的函数,该函数由on_generate_code调用以获取代码转换器。
此函数还会将其输入作为当前注册的代码转换器(如果没有注册则为None),以防不希望覆盖它。这对于将代码转换器链接在一起非常有用。
- Returns:
一个上下文管理器,当在with语句中使用时,自动恢复之前注册的代码转换器。
示例:
gm: fx.GraphModule = ... # 这是一个我们想要注册的代码转换器。这个代码转换器在生成的torch.fx代码的最前面插入一个pdb导入和跟踪语句,以便使用PDB库进行手动调试。 def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] # 注册`insert_pdb`,并覆盖当前注册的代码转换器(由lambda中的`_`给出): gm.graph.on_generate_code( lambda _: insert_pdb ) # 或者,注册一个代码转换器,该转换器首先通过现有的注册转换器运行`body`,然后通过`insert_pdb`运行: gm.graph.on_generate_code( lambda current_trans: ( lambda body: insert_pdb( current_trans(body) if current_trans else body ) ) ) gm.recompile() gm(*inputs) # 进入pdb
此函数也可以用作上下文管理器,其好处是自动恢复之前注册的代码转换器:
# ... 继续自上一个示例 with gm.graph.on_generate_code(lambda _: insert_pdb): # 对 `gm` 进行更多操作... gm.recompile() gm(*inputs) # 进入 pdb # 现在恢复了之前的代码转换器(但带有 pdb 的 `gm` 代码仍然存在 - 这意味着你可以在这里使用 pdb 运行 `gm`,直到你 # 运行下一个 `recompile()`)。
警告
此API是实验性的,并且不向后兼容。
- output(result, type_expr=None)[源代码]¶
插入一个
outputNode到Graph中。一个output节点代表Python代码中的一个return语句。result是应该返回的值。- Parameters
结果 (参数) – 要返回的值。
type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。
注意
此方法的插入点和类型表达式规则与
Graph.create_node相同。注意
此API的向后兼容性得到保证。
- placeholder(name, type_expr=None, default_value)[源代码]¶
在图中插入一个
占位符节点。一个占位符表示一个函数输入。- Parameters
名称 (字符串) – 输入值的名称。这对应于此
Graph表示的函数的定位参数的名称。type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。在某些情况下,这是正确代码生成所必需的(例如,当函数随后用于TorchScript编译时)。
default_value (任意) – 此函数参数应采用的默认值。注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 作为此参数传递,以指定该参数没有默认值。
- Return type
注意
此方法的插入点和类型表达式规则与
Graph.create_node相同。注意
此API的向后兼容性得到保证。
- class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[源代码]¶
Node是表示Graph中各个操作的数据结构。在大多数情况下,节点表示对各种实体的调用站点,例如操作符、方法和模块(一些例外包括指定函数输入和输出的节点)。每个Node都有一个由其op属性指定的函数。每个op值的Node语义如下:placeholder表示一个函数输入。name属性指定此值将采用的名称。target同样表示参数的名称。args包含以下之一:1) 无,或 2) 表示函数输入默认参数的单个参数。kwargs不关心。占位符对应于图输出中的函数参数(例如x)。get_attr从模块层次结构中检索一个参数。name是同样地,获取结果被分配到的名称。target是参数在模块层次结构中的完全限定名称。args和kwargs是无关紧要的call_function将一个自由函数应用于某些值。name同样是要赋值的值的名称。target是要应用的函数。args和kwargs表示函数的参数,遵循 Python 调用约定。call_module在模块层次结构的forward()方法中应用一个模块到给定的参数。name与之前相同。target是模块层次结构中要调用的模块的完全限定名称。args和kwargs表示调用模块的参数,不包括 self 参数。call_method调用一个值上的方法。name类似。target是要应用于self参数的方法的字符串名称。args和kwargs表示要调用模块的参数,包括 self 参数output包含 traced 函数在其args[0]属性中的输出。这对应于 Graph 打印输出中的“return”语句。
注意
此API的向后兼容性得到保证。
- property all_input_nodes: List[节点]¶
返回所有作为此节点输入的节点。这相当于迭代
args和kwargs,并且只收集那些是节点的值。- Returns
在此
Node的args和kwargs中出现的Nodes列表,按该顺序排列。
- append(x)[源代码]¶
在图的节点列表中,在此节点后插入
x。 等同于self.next.prepend(x)- Parameters
x (节点) – 要放在此节点之后的节点。必须是同一图的成员。
注意
此API的向后兼容性得到保证。
- property args: Tuple[Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, 节点, str, int, float, bool, complex, 数据类型, 张量, 设备, 内存格式, 布局, OpOverload]], ...]¶
这个
Node的参数元组。参数的解释取决于节点的操作码。更多信息请参阅Node文档字符串。允许对此属性进行赋值。在赋值时,所有使用和用户的记录都会自动更新。
- format_node(placeholder_names=None, maybe_return_typename=None)[源码]¶
返回一个描述性的字符串表示形式。
此方法可以不带参数使用,作为调试工具。
此函数也在
Graph的__str__方法中内部使用。placeholder_names和maybe_return_typename中的字符串共同构成了此图所包围的 GraphModule 中自动生成的forward函数的签名。placeholder_names和maybe_return_typename不应在其他情况下使用。- Parameters
- Returns
- 如果1) 我们使用
format_node作为内部辅助函数 在
Graph的__str__方法中,并且 2)self是一个占位符节点,返回None。否则,返回当前节点的描述性字符串表示。
- 如果1) 我们使用
- Return type
注意
此API的向后兼容性得到保证。
- insert_arg(idx, arg)[源代码]¶
在参数列表中插入一个带有指定索引的位置参数。
- Parameters
idx (int) – 在
self.args中要插入元素的索引。arg (参数) – 要插入到
args中的新参数值
注意
此API的向后兼容性得到保证。
- is_impure()[源代码]¶
返回此操作是否为不纯操作,即如果其操作是占位符或输出,或者如果调用的是不纯的call_function或call_module。
- Returns
如果操作是纯操作或非纯操作。
- Return type
警告
此API是实验性的,并且不向后兼容。
- property kwargs: Dict[str, Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, 节点, str, int, float, bool, complex, 数据类型, 张量, 设备, 内存格式, 布局, OpOverload]]]¶
传递给此
节点的关键字参数字典。参数的解释取决于节点的操作码。更多信息请参阅节点文档字符串。允许对此属性进行赋值。在赋值时,所有使用和用户的记录都会自动更新。
- normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[源代码]¶
返回标准化后的参数给Python目标。这意味着 args/kwargs 将被匹配到模块/函数的签名,并且如果 normalize_to_only_use_kwargs 为真,则按位置顺序返回仅包含kwargs的参数。 同时填充默认值。不支持仅位置参数或可变参数。
支持模块调用。
可能需要 arg_types 和 kwarg_types 以消除重载的歧义。
- Parameters
根 (torch.nn.Module) – 要解析模块目标的模块。
arg_types (可选[元组[任意]]) – 参数的参数类型元组
kwarg_types (可选[字典[str, 任意]]) – 关键字参数的类型字典
normalize_to_only_use_kwargs (布尔值) – 是否规范化为仅使用关键字参数。
- Returns
返回 NamedTuple ArgsKwargsPair,如果未成功则返回 None。
- Return type
可选[ArgsKwargsPair]
警告
此API是实验性的,并且不向后兼容。
- prepend(x)[源代码]¶
在图的节点列表中,在此节点之前插入 x。示例:
之前: p -> self bx -> x -> ax 之后: p -> x -> self bx -> ax
- Parameters
x (节点) – 要放在此节点之前的节点。必须是同一图的成员。
注意
此API的向后兼容性得到保证。
- replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[源代码]¶
将Graph中所有使用的
self替换为Nodereplace_with。- Parameters
- Returns
此更改所涉及的节点列表。
- Return type
注意
此API的向后兼容性得到保证。
- replace_input_with(old_input, new_input)[源代码]¶
遍历
self的输入节点,并将所有old_input的实例替换为new_input。注意
此API的向后兼容性得到保证。
- property stack_trace: Optional[str]¶
返回在追踪过程中记录的Python堆栈跟踪,如果有的话。 当使用fx.Tracer进行追踪时,此属性通常由Tracer.create_proxy填充。为了在调试目的下记录追踪期间的堆栈跟踪,请在Tracer实例上设置record_stack_traces = True。 当使用dynamo进行追踪时,此属性将默认由OutputGraph.create_proxy填充。
stack_trace 将在字符串的末尾包含最内层的帧。
- class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[源代码]¶
Tracer是实现torch.fx.symbolic_trace符号追踪功能的类。调用symbolic_trace(m)等同于调用Tracer().trace(m)。Tracer 可以被继承以覆盖跟踪过程中各种行为。可以覆盖的不同行为在该类的各个方法的文档字符串中进行了描述。
注意
此API的向后兼容性得到保证。
- call_module(m, forward, args, kwargs)[源代码]¶
指定此
Tracer在遇到对nn.Module实例的调用时的行为的方法。默认情况下,行为是检查被调用的模块是否为叶模块,通过
is_leaf_module。如果是,则发出一个引用m的call_module节点在Graph中。否则,正常调用Module,跟踪其forward函数中的操作。此方法可以被重写以——例如——创建嵌套的跟踪GraphModules,或您希望在跨越
Module边界时所需的任何其他行为。- Parameters
m (模块) – 正在发出调用的模块
forward (可调用对象) – 要调用的
Module的 forward() 方法args (元组) – 模块调用点的参数
kwargs (字典) – 模块调用点的kwargs
- Returns
Module 调用的返回值。如果发出了一个
call_module节点,则这是一个Proxy值。否则,它是从Module调用返回的任何值。- Return type
注意
此API的向后兼容性得到保证。
- create_arg(a)[源代码]¶
指定在准备值以用作
Graph中节点的参数时,跟踪行为的指定方法。默认情况下,行为包括:
遍历集合类型(例如元组、列表、字典)并对元素递归调用
create_args。给定一个代理对象,返回对底层 IR
Node的引用给定一个非代理张量对象,为各种情况生成IR:
对于一个参数,生成一个引用该参数的
get_attr节点对于非参数张量,将张量存储在一个特殊属性中,该属性引用该属性。
此方法可以被重写以支持更多类型。
- Parameters
a (任意) – 作为
Argument在Graph中发出的值。- Returns
将值
a转换为适当的Argument- Return type
可选[联合[元组[任意, …], 列表[任意], 字典[字符串, 任意], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载]]
注意
此API的向后兼容性得到保证。
- create_args_for_root(root_fn, is_module, concrete_args=None)[源代码]¶
创建与
root模块签名相对应的placeholder节点。此方法内省root的签名并相应地发出这些节点,同时支持*args和**kwargs。警告
此API是实验性的,并且不向后兼容。
- create_node(kind, target, args, kwargs, name=None, type_expr=None)¶
插入一个图节点,给定目标、参数、关键字参数和名称。
此方法可以被重写以进行额外的检查、验证或修改用于节点创建的值。例如,可能希望禁止记录就地操作。
注意
此API的向后兼容性得到保证。
- Return type
- create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)¶
从给定的参数创建一个节点,然后返回包装在代理对象中的节点。
如果 kind = ‘placeholder’,那么我们正在创建一个表示函数参数的节点。如果我们需要编码一个默认参数,我们使用
args元组。对于placeholder节点,args通常是空的。注意
此API的向后兼容性得到保证。
- getattr(attr, attr_val, parameter_proxy_cache)[源代码]¶
指定在调用
nn.Module实例时,当我们对Tracer调用 getattr 时的行为的函数。默认情况下,行为是返回属性的代理值。它还会将代理值存储在
parameter_proxy_cache中,以便将来的调用将重用该代理,而不是创建一个新的代理。此方法可以被重写以——例如——在查询参数时不返回代理。
- Parameters
- Returns
getattr 调用的返回值。
警告
此API是实验性的,并且不向后兼容。
- is_leaf_module(m, module_qualified_name)[源代码]¶
指定给定的
nn.Module是否为“叶子”模块的方法。叶子模块是出现在IR中的原子单元,通过
call_module调用引用。默认情况下,PyTorch标准库命名空间(torch.nn)中的模块是叶子模块。所有其他模块都会被追踪,并且它们的组成操作会被记录,除非通过此参数另行指定。- Parameters
- Return type
注意
此API的向后兼容性得到保证。
- iter(obj)¶
- Called when a proxy object is being iterated over, such as
当用于控制流程时。通常我们不知道该做什么,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个迭代器。
注意
此API的向后兼容性得到保证。
- Return type
- keys(obj)¶
- Called when a proxy object is has the keys() method called.
这是在代理上调用 ** 时发生的情况。如果 ** 应该在您的自定义跟踪器中工作,则应返回一个迭代器。
注意
此API的向后兼容性得到保证。
- Return type
- path_of_module(mod)[源代码]¶
用于在
root的模块层次结构中查找mod的限定名称的辅助方法。例如,如果root有一个名为foo的子模块,该子模块又有一个名为bar的子模块,将bar传递给此函数将返回字符串 “foo.bar”。注意
此API的向后兼容性得到保证。
- to_bool(obj)¶
- Called when a proxy object is being converted to a boolean, such as
当用于控制流程时。通常我们不知道该做什么,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个值。
注意
此API的向后兼容性得到保证。
- Return type
- class torch.fx.Proxy(node, tracer=None)[源代码]¶
Proxy对象是Node包装器,它们在符号追踪过程中流经程序,并记录所有操作(torch函数调用、方法调用、运算符)到不断增长的 FX 图中。如果你正在进行图变换,你可以将你自己的
Proxy方法包装在一个原始的Node周围,这样你就可以使用重载的运算符向Graph添加额外的东西。Proxy对象不能被迭代。换句话说,如果在一个循环中使用Proxy或在函数参数中使用*args/**kwargs,符号追踪器将会抛出错误。有两种主要的方法来解决这个问题: 1. 将不可追踪的逻辑提取到一个顶层函数中,并使用
fx.wrap对其进行包装。 2. 如果控制流是静态的(即循环次数基于某些超参数),代码可以保留在原位并重构为如下形式:for i in range(self.some_hyperparameter): indexed_item = proxied_value[i]
有关Proxy内部更详细的描述,请查看torch/fx/OVERVIEW.md中的“Proxy”部分
注意
此API的向后兼容性得到保证。
- class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[源代码]¶
解释器逐节点执行FX图。这种模式可以用于许多方面,包括编写代码转换以及分析过程。
Interpreter类中的方法可以被重写以自定义执行行为。以下是可重写方法的调用层次结构图:
run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output()
示例
假设我们想要将所有实例的
torch.neg替换为torch.sigmoid并且反之亦然(包括它们的Tensor方法等效项)。我们可以像这样子类化 Interpreter:```python class NegSigmSwapInterpreter(Interpreter): def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == 'neg': call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid()) ```- Parameters
模块 (torch.nn.Module) – 要执行的模块
garbage_collect_values (bool) – 是否在模块执行过程中删除值在其最后一次使用后。这确保了执行期间的最佳内存使用。可以禁用此功能,例如,通过查看
Interpreter.env属性来检查执行中的所有中间值。graph (可选[Graph]) – 如果传递,解释器将执行此图而不是 module.graph,使用提供的 module 参数来满足任何状态请求。
注意
此API的向后兼容性得到保证。
- boxed_run(args_list)[源代码]¶
通过解释运行模块并返回结果。 这使用了“装箱”调用约定,其中您传递一个参数列表,这些参数将由解释器清除。 这确保了输入张量被及时释放。
注意
此API的向后兼容性得到保证。
- call_function(target, args, kwargs)[源代码]¶
执行一个
call_function节点并返回结果。- Parameters
目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (元组) – 此调用的位置参数元组
kwargs (字典) – 本次调用的关键字参数字典
- Return type
- Return
任意:函数调用返回的值
注意
此API的向后兼容性得到保证。
- call_method(target, args, kwargs)[源代码]¶
执行一个
call_method节点并返回结果。- Parameters
目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此次调用的关键字参数字典
- Return type
- Return
任意:方法调用返回的值
注意
此API的向后兼容性得到保证。
- call_module(target, args, kwargs)[源代码]¶
执行一个
call_module节点并返回结果。- Parameters
目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此次调用的关键字参数字典
- Return type
- Return
任意: 模块调用返回的值
注意
此API的向后兼容性得到保证。
- fetch_args_kwargs_from_env(n)[源代码]¶
从当前执行环境中获取节点
n的args和kwargs的具体值。- Parameters
n (节点) – 要获取
args和kwargs的节点。- Returns
args和kwargs使用具体的值作为n。- Return type
元组[元组, 字典]
注意
此API的向后兼容性得到保证。
- fetch_attr(target)[源代码]¶
从
self.module的Module层次结构中获取一个属性。- Parameters
目标 (str) – 要获取的属性的完全限定名称
- Returns
属性的值。
- Return type
任意
注意
此API的向后兼容性得到保证。
- get_attr(target, args, kwargs)[源代码]¶
执行一个
get_attr节点。将从self.module的Module层次结构中检索属性值。- Parameters
目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此次调用的关键字参数字典
- Returns
被检索到的属性的值
- Return type
任意
注意
此API的向后兼容性得到保证。
- map_nodes_to_values(args, n)[源代码]¶
递归地遍历
args并在当前执行环境中查找每个Node的具体值。- Parameters
args (参数) – 在其中查找具体值的数据结构
n (节点) – 参数
args所属的节点。这仅用于错误报告。
- Return type
可选[联合[元组[任意, …], 列表[任意], 字典[字符串, 任意], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载]]
注意
此API的向后兼容性得到保证。
- output(target, args, kwargs)[源代码]¶
执行一个
output节点。这实际上只是检索由output节点引用的值并返回它。- Parameters
目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此次调用的关键字参数字典
- Returns
输出节点引用的返回值
- Return type
任意
注意
此API的向后兼容性得到保证。
- placeholder(target, args, kwargs)[源代码]¶
执行一个
placeholder节点。请注意,这是有状态的:Interpreter维护了一个内部迭代器,用于遍历传递给run的参数,并且此方法返回该迭代器的next()。- Parameters
目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此次调用的关键字参数字典
- Returns
获取到的参数值。
- Return type
任意
注意
此API的向后兼容性得到保证。
- class torch.fx.Transformer(module)[源代码]¶
Transformer是一种特殊类型的解释器,它生成一个新的Module。它公开了一个transform()方法,该方法返回转换后的Module。Transformer不需要像Interpreter那样运行参数。Transformer完全以符号方式工作。示例
假设我们想要将所有实例的
torch.neg替换为torch.sigmoid以及反之亦然(包括它们的Tensor方法等价物)。我们可以像这样子类化Transformer:class NegSigmSwapXformer(Transformer): def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: if target == 'neg': call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
- Parameters
模块 (GraphModule) – 要被转换的
模块。
注意
此API的向后兼容性得到保证。
- get_attr(target, args, kwargs)[源代码]¶
执行一个
get_attr节点。在Transformer中,这被重写以在输出图中插入一个新的get_attr节点。- Parameters
目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此次调用的关键字参数字典
- Return type
注意
此API的向后兼容性得到保证。
- torch.fx.replace_pattern(gm, pattern, replacement)[源代码]¶
匹配图模块(
gm)图中所有可能的非重叠操作符及其数据依赖集(pattern),然后使用另一个子图(replacement)替换每个匹配的子图。- Parameters
- Returns
表示在原始图中匹配到
pattern的位置的Match对象列表。如果没有匹配项,则该列表为空。Match定义为:class Match(NamedTuple): # 匹配找到的节点 anchor: Node # 将模式子图中的节点映射到较大图中的节点 nodes_map: Dict[Node, Node]
- Return type
列表[匹配]
示例:
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]).sum() def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上述代码将首先在
traced_module的forward方法中匹配pattern。模式匹配是基于使用-定义关系进行的,而不是基于节点名称。例如,如果你在pattern中有p = torch.cat([a, b]),你可以匹配原始forward函数中的m = torch.cat([a, b]),尽管变量名称不同(p与m)。在
pattern中的return语句仅根据其值进行匹配;它可能与较大图中的return语句匹配,也可能不匹配。换句话说,模式不必延伸到较大图的末尾。当模式匹配时,它将从较大的函数中移除,并替换为
replacement。如果在较大的函数中存在多个pattern的匹配项,每个不重叠的匹配项都将被替换。在匹配重叠的情况下,首先找到的重叠匹配项将被替换。(这里的“首先”是根据节点使用-定义关系的拓扑排序定义的。在大多数情况下,第一个节点是直接出现在self之后的参数,而最后一个节点是函数返回的内容。)需要注意的一个重要事项是,
pattern可调用对象的参数必须在可调用对象本身中使用,并且replacement可调用对象的参数必须与模式匹配。第一个规则就是为什么在上面的代码块中,forward函数有参数x, w1, w2,但pattern函数只有参数w1, w2。pattern不使用x,所以它不应该将x指定为参数。作为第二个规则的示例,考虑替换def pattern(x, y): return torch.neg(x) + torch.relu(y)
与
def replacement(x, y): return torch.relu(x)
在这种情况下,
replacement需要与pattern相同的参数数量(x和y),即使参数y在replacement中未被使用。在调用
subgraph_rewriter.replace_pattern之后,生成的 Python 代码如下所示:def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
此API的向后兼容性得到保证。