Shortcuts

torch.fx

概述

FX 是一个供开发者使用的工具包,用于转换 nn.Module 实例。FX 由三个主要组件组成:一个 符号追踪器, 一个 中间表示,以及 Python 代码生成。以下是这些组件的实际演示:

import torch
# 简单的模块示例
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# 符号追踪前端 - 捕获模块的语义
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# 高级中间表示(IR)- 图表示
print(symbolic_traced.graph)
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# 代码生成 - 有效的Python代码
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号追踪器执行 Python 代码的“符号执行”。它通过代码传递称为代理的假值。这些代理的操作会被记录下来。有关符号追踪的更多信息,可以在 symbolic_trace()Tracer 文档中找到。

中间表示(intermediate representation)是用于存储在符号追踪过程中记录的操作的容器。它由一系列节点组成,这些节点表示函数输入、调用点(对函数、方法或torch.nn.Module实例的调用)以及返回值。有关IR的更多信息,可以在Graph的文档中找到。IR是应用转换的格式。

Python代码生成是使FX成为Python到Python(或模块到模块)转换工具包的原因。对于每个Graph IR,我们可以创建与Graph语义匹配的有效Python代码。此功能被封装在GraphModule中,它是一个torch.nn.Module实例,持有Graph以及从Graph生成的forward方法。

总的来说,这一系列组件(符号追踪 -> 中间表示 -> 转换 -> Python代码生成)构成了FX的Python到Python转换管道。此外,这些组件可以单独使用。例如,符号追踪可以单独用于捕获代码的形式以进行分析(而非转换)目的。代码生成可以用于以编程方式生成模型,例如从配置文件生成。FX有很多用途!

可以在 示例 仓库中找到几个示例转换。

编写转换

什么是FX变换?本质上,它是一个看起来像这样的函数。

import torch
import torch.fx

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # 步骤1:获取表示 `m` 中代码的图

    # 注意:torch.fx.symbolic_trace 是调用 fx.Tracer.trace 并构造 GraphModule 的包装器。我们将在我们的变换中将其拆分,以允许调用者自定义跟踪行为。
    graph : torch.fx.Graph = tracer_class().trace(m)

    # 步骤2:修改此图或创建一个新图
    graph = ...

    # 步骤3:构造一个要返回的模块
    return torch.fx.GraphModule(m, graph)

您的转换将接收一个torch.nn.Module,从中获取一个Graph,进行一些修改,并返回一个新的torch.nn.Module。您应该将您的FX转换返回的torch.nn.Module视为与常规torch.nn.Module相同——您可以将其传递给另一个FX转换,可以将其传递给TorchScript,或者可以运行它。确保您的FX转换的输入和输出是一个torch.nn.Module将允许组合性。

注意

也可以修改现有的 GraphModule,而不是创建一个新的,如下所示:

import torch
import torch.fx

def transform(m : nn.Module) -> nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # 修改 gm.graph
    # <...>

    # 从其 Graph 重新编译 `gm` 的 forward() 方法
    gm.recompile()

    return gm

请注意,您必须调用 GraphModule.recompile() 以使生成的 forward() 方法与修改后的 Graph 同步。

假设你已经传入了一个已经被追踪到Graphtorch.nn.Module,现在有两种主要的方法可以用来构建一个新的Graph

图的快速入门

图的语义的完整处理可以在Graph文档中找到,但我们在这里将介绍基础知识。一个Graph是一个表示GraphModule上的方法的数据结构。这需要的信息是:

  • 方法的输入是什么?

  • 方法内部运行了哪些操作?

  • 方法的输出(即返回)值是什么?

这三个概念都用 Node 实例表示。 让我们通过一个简短的例子来看看这意味着什么:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

在这里,我们定义了一个模块 MyModule 用于演示目的,实例化它, 符号化地追踪它,然后调用 Graph.print_tabular() 方法来打印 出一个表格,显示这个 Graph 的节点:

操作码

名称

目标

参数

关键字参数

占位符

x

x

()

{}

获取属性

线性权重

线性.权重

()

{}

调用函数

add_1

<内置函数 add>

(x, linear_weight)

{}

调用模块

线性_1

线性

(add_1,)

{}

调用方法

relu_1

ReLU

(linear_1,)

{}

调用函数

sum_1

<内置方法 sum …>

(relu_1,)

{‘dim’: -1}

调用函数

topk_1

<内置方法 topk …>

(sum_1, 3)

{}

输出

输出

输出

(topk_1,)

{}

我们可以使用这些信息来回答我们上面提出的问题。

  • 方法的输入是什么?在FX中,方法输入是通过特殊的placeholder节点指定的。在这种情况下,我们有一个placeholder节点,其targetx,这意味着我们有一个名为x的单一(非self)参数。

  • 方法中的操作是什么?get_attrcall_functioncall_modulecall_method 节点 表示方法中的操作。所有这些的语义的完整处理可以在 Node 文档中找到。

  • 方法的返回值是什么?在Graph中,返回值由一个特殊的output节点指定。

鉴于我们现在了解了代码在FX中的基本表示方式,我们可以探讨如何编辑一个Graph

图操作

直接图操作

构建这个新的Graph的一种方法是直接操作旧的图。为了帮助实现这一点,我们可以简单地获取从符号追踪中得到的Graph并对其进行修改。例如,假设我们希望将torch.add()调用替换为torch.mul()调用。

import torch
import torch.fx

# 示例模块
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # FX 将其 Graph 表示为一个有序的节点列表,
    # 因此我们可以遍历它们。
    for node in graph.nodes:
        # 检查我们是否在调用一个函数(例如:
        # torch.add)
        if node.op == 'call_function':
            # target 属性是 call_function 调用的函数。
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # 进行一些检查以确保
                 # Graph 是良构的。

    return fx.GraphModule(m, graph)

我们还可以进行更复杂的Graph重写,例如删除或追加节点。为了帮助这些转换,FX提供了一些用于转换图的实用函数,可以在Graph文档中找到。下面是一个使用这些API追加torch.relu()调用的示例。

# 指定插入点。在此范围内添加到
# Graph中的任何节点都将插入到`node`之后
with traced.graph.inserting_after(node):
    # 插入一个新的`call_function`节点,调用`torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # 我们希望所有使用`node`值的地方
    # 现在使用我们在`relu`调用后添加的值。
    # 我们使用`replace_all_uses_with` API来实现这一点。
    node.replace_all_uses_with(new_node)

对于仅包含替换的简单变换,您也可以使用子图重写器。

使用 replace_pattern() 进行子图重写

FX 还提供了在直接图形操作之上的另一层自动化。 replace_pattern() API 本质上是一个用于编辑 Graph 的“查找/替换”工具。它允许你指定一个 patternreplacement 函数 并且它会跟踪这些函数,找到 pattern 图形中操作组的实例,并用 replacement 图形的副本来替换这些实例。这可以帮助大大自动化繁琐的图形操作代码,当转换变得更加复杂时,这些代码可能会变得难以管理。

代理/重追溯

另一种操作 Graph 的方法是重用符号追踪中使用的 Proxy 机制。例如,假设我们想要编写一个将 PyTorch 函数分解为更小操作的转换。它会将每个 F.relu(x) 调用转换为 (x > 0) * x。一种可能性是执行必要的图重写,在 F.relu 之后插入比较和乘法,然后清理原始的 F.relu。然而,我们可以通过使用 Proxy 对象来自动记录操作到 Graph 中来自动化这个过程。

要使用此方法,我们将要插入的操作编写为常规的 PyTorch 代码,并使用 Proxy 对象作为参数调用该代码。 这些 Proxy 对象将捕获对其执行的操作,并将它们附加到 Graph 中。

# 注意,这个分解规则可以像普通的Python一样阅读
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """
    将`model`分解为更小的组成操作。
    目前,这仅支持将ReLU分解为其
    数学定义:(x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # 通过用代理包装参数,
            # 我们可以分派到适当的
            # 分解规则,并通过符号化跟踪将其隐式添加到图表中。
            proxy_args = [
                fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # 对`Proxy`的操作总是产生新的`Proxy`,并且
            # 我们的分解规则的返回值也不例外。
            # 我们需要从`Proxy`中提取底层的`Node`
            # 以便在后续的转换迭代中使用它。
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # 默认情况:我们没有这个节点的分解规则,
            # 所以只需将节点复制到新图中。
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

除了避免显式图操作外,使用Proxy还允许您将重写规则指定为原生Python代码。对于需要大量重写规则的转换(如vmap或grad),这通常可以提高规则的可读性和可维护性。请注意,虽然调用Proxy时我们也传递了一个指向底层变量graph的跟踪器。这样做是为了防止如果图中的操作是n元操作(例如,add是一个二元运算符),调用Proxy不会创建多个图跟踪器实例,这可能导致意外的运行时错误。我们建议使用这种方法,特别是当底层操作不能安全地假设为单一时。

使用Proxy进行Graph操作的示例可以在 这里找到。

解释器模式

在FX中,一个有用的代码组织模式是遍历一个Graph中的所有Node并执行它们。这可以用于多种用途,包括对流经图的值进行运行时分析或通过使用Proxy进行重跟踪来转换代码。例如,假设我们想要运行一个GraphModule并在运行时记录我们在节点上看到的torch.Tensor的形状和dtype属性。这可能看起来像:

import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    形状传播。这个类接受一个 `GraphModule`。
    然后,它的 `propagate` 方法使用给定的参数逐节点执行 `GraphModule`。
    在每个操作执行时,ShapeProp 类会存储每个操作输出值的形状和
    元素类型在操作的 `Node` 的 `shape` 和 `dtype` 属性上。
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # 这是形状传播特有的代码。
            # 你可以删除这个 `if` 分支,这样就变成了
            # 一个通用的 GraphModule 解释器。
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

正如你所见,一个完整的FX解释器并不那么复杂,但它可以非常有用。为了简化使用这种模式,我们提供了Interpreter类,它包含了上述逻辑,并且可以通过方法重写来覆盖解释器执行的某些方面。

除了执行操作外,我们还可以通过将Proxy值传递给解释器来生成一个新的Graph。 同样,我们提供了Transformer类来涵盖这种模式。Transformer的行为类似于 Interpreter,但不是调用run方法从模块中获取具体的输出值,而是调用 Transformer.transform()方法来返回一个新的 GraphModule,该模块受到您安装的任何转换规则的约束,这些规则作为重写方法。

解释器模式示例

调试

介绍

在编写转换的过程中,我们的代码往往并不完全正确。 在这种情况下,我们可能需要进行一些调试。关键是从后往前进行:首先,检查调用生成模块的结果以验证或反驳正确性。然后,检查并调试生成的代码。最后,调试导致生成代码的转换过程。

如果您不熟悉调试器,请参阅辅助部分 可用的调试器

转换创作中的常见陷阱

  • 非确定性的 set 迭代顺序。在 Python 中,set 数据类型是无序的。使用 set 来包含像 Node 这样的对象集合,例如,可能会导致意外的非确定性。一个例子是迭代一组 Node 并将它们插入到 Graph 中。由于 set 数据类型是无序的,输出程序中的操作顺序将是不确定的,并且在程序的不同调用之间可能会发生变化。推荐的替代方法是使用 dict 数据类型,自 Python 3.7(以及 cPython 3.6)以来,它是 插入顺序。可以通过将需要去重的值存储在 dict 的键中来等效地使用 dict

检查模块的正确性

因为大多数深度学习模块的输出由浮点数 torch.Tensor 实例组成,检查两个 torch.nn.Module 的结果是否相等并不像进行简单的等式检查那样直接。为了说明这一点,让我们使用一个 例子:

import torch
import torch.fx
import torchvision.models as models

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm = torch.fx.symbolic_trace(m)

    # 想象一下我们在这里做了一些转换
    # <...>

    gm.recompile()

    return gm

resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)

input_image = torch.randn(5, 3, 224, 224)

assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: 张量包含多个值时的布尔值是模糊的
"""

在这里,我们尝试使用 == 相等运算符检查两个深度学习模型的值是否相等。然而,由于该运算符返回的是张量而不是布尔值,并且浮点值的比较应使用误差范围(或epsilon)来考虑浮点运算的非交换性(更多详情请参见 这里),因此这种方法定义不明确。我们可以使用 torch.allclose() 来代替,它将为我们提供一个近似比较,考虑相对和绝对容差阈值:

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

这是我们工具箱中的第一个工具,用于检查转换后的模块是否与我们期望的参考实现表现一致。

调试生成的代码

因为FX在GraphModule上生成了forward()函数,使用传统的调试技术,如print语句或pdb,并不那么直接。幸运的是,我们有几种技术可以用来调试生成的代码。

使用 pdb

调用 pdb 进入正在运行的程序。尽管表示 Graph 的代码不在任何源文件中,我们仍然可以在调用前向传播时手动使用 pdb 进入它。

import torch
import torch.fx
import torchvision.models as models

def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(inp)
    # 转换逻辑在这里
    # <...>

    # 返回新的模块
    return fx.GraphModule(inp, graph)

my_module = models.resnet18()
my_module_transformed = my_pass(my_module)

input_value = torch.randn(5, 3, 224, 224)

# 当这一行在运行时执行时,我们将进入一个交互式的 `pdb` 提示符。我们可以使用 `step` 或 `s` 命令来
# 进入下一行的执行
import pdb; pdb.set_trace()

my_module_transformed(input_value)

使用 to_folder 函数从 GraphModule

GraphModule.to_folder()GraphModule 中的一个方法,允许你将生成的 FX 代码转储到一个文件夹中。虽然通常只需将前向传播代码复制到代码中即可,如 打印生成的代码 所示,但使用 to_folder 可能更容易检查模块和参数。

m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

运行上述示例后,我们可以查看 foo/module.py 中的代码并根据需要进行修改(例如添加 print 语句或使用 pdb)来调试生成的代码。

调试转换

既然我们已经确定了一个转换正在生成错误的代码,那么是时候调试这个转换本身了。首先,我们会检查文档中的符号追踪的局限性部分。一旦我们确认追踪是按预期工作的,目标就变成了找出在我们的GraphModule转换过程中出了什么问题。在编写转换中可能有一个快速的答案,但如果没有,有几种方法可以检查我们追踪的模块:

```python
# 示例模块
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# 创建 `M` 的一个实例
m = M()

# 符号化地追踪 `M` 的一个实例(返回一个 GraphModule)。在这个例子中,我们只讨论如何检查一个
# GraphModule,所以我们没有展示任何示例转换,以保持简洁。
traced = symbolic_trace(m)

# 打印由追踪模块生成的代码。
print(traced)
# 生成的 `forward` 函数是:
"""
def forward(self, x, y):
    add = x + y;  x = y = None
    return add
"""

# 打印内部图。
print(traced.graph)
# 这个打印输出返回:
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    return add
"""

# 打印内部图的表格表示。
traced.graph.print_tabular()
# 这给了我们:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add       (x, y)  {}
output         output  output                   (add,)  {}
"""
```

使用上述实用函数,我们可以在应用转换之前和之后比较我们的跟踪模块。有时,简单的视觉比较足以追踪到一个错误。如果仍然不清楚哪里出了问题,像pdb这样的调试器可以是一个很好的下一步。

基于上述示例,考虑以下代码:

# 示例用户定义函数
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    # 从我们的跟踪模块中获取图
    g = tracer_class().trace(module)

    """
    对 `g` 的转换在这里进行
    """

    return fx.GraphModule(module, g)

# 转换图
transformed = transform_graph(traced)

# 打印转换后的新代码。检查是否符合我们的预期
print(transformed)

使用上述示例,假设调用 print(traced) 显示我们的转换中存在错误。我们想使用调试器找出问题所在。我们启动一个 pdb 会话。我们可以通过在 transform_graph(traced) 处中断,然后按 s 来“步入”对 transform_graph(traced) 的调用,从而查看转换过程中发生了什么。

我们也可以通过编辑 print_tabular 方法来打印图中的节点的不同属性,从而获得好运。(例如,我们可能想要查看节点的 input_nodesusers。)

可用的调试器

最常见的 Python 调试器是 pdb。你可以通过在命令行中输入 python -m pdb FILENAME.py 以“调试模式”启动你的程序,其中 FILENAME 是你想要调试的文件名。之后,你可以使用 pdb 调试器命令 逐步移动你的运行程序。通常在启动 pdb 时设置一个 断点(b LINE-NUMBER),然后调用 c 运行程序直到该点。这可以避免你必须 逐行执行(使用 sn)以到达你想要检查的代码部分。或者,你可以在你想要中断的行之前写入 import pdb; pdb.set_trace()。如果你添加了 pdb.set_trace(),你的程序将在运行时自动 进入调试模式。(换句话说,你可以直接在命令行中输入 python FILENAME.py 而不是 python -m pdb FILENAME.py。)一旦你在调试模式下运行你的文件,你可以逐步执行代码并使用某些命令检查你的程序的内部状态。网上有许多关于 pdb 的优秀教程,包括 RealPython 的 “使用 Pdb 进行 Python 调试”

像 PyCharm 或 VSCode 这样的 IDE 通常内置了调试器。在你的 IDE 中,你可以选择 a) 通过在 IDE 中打开一个终端窗口(例如在 VSCode 中,选择 View → Terminal)使用 pdb,或者 b) 使用内置的调试器(通常是围绕 pdb 的图形界面)。

符号追踪的局限性

FX 使用一种 符号追踪(又名 符号执行)系统, 以捕获程序的语义并以可转换/可分析的形式表示。该系统是 追踪 的,因为它执行程序(实际上是一个 torch.nn.Module 或函数)以记录操作。它是 符号 的,因为在执行过程中流经程序的数据不是真实数据,而是符号(在 FX 术语中称为 Proxy)。

尽管符号追踪适用于大多数神经网络代码,但它也有一些局限性。

动态控制流

符号追踪的主要限制是它目前不支持动态控制流。也就是说,循环或if语句,其中条件可能依赖于程序的输入值。

例如,让我们检查以下程序:

def func_to_trace(x):
    if x.sum() > 0:
        return torch.relu(x)
    else:
        return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)
"""
  <...>
  文件 "dyn.py",第 6 行,在 func_to_trace 中
    if x.sum() > 0:
  文件 "pytorch/torch/fx/proxy.py",第 155 行,在 __bool__ 中
    return self.tracer.to_bool(self)
  文件 "pytorch/torch/fx/proxy.py",第 85 行,在 to_bool 中
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

if语句的条件依赖于x.sum()的值,而x.sum()的值又依赖于函数输入x的值。由于x可能会发生变化(例如,如果你将一个新的输入张量传递给跟踪函数),这就是动态控制流。回溯会通过你的代码向上追溯,以显示这种情况发生的位置。

静态控制流

另一方面,所谓的静态控制流是受支持的。静态控制流是指在调用过程中其值不会改变的循环或if语句。通常,在PyTorch程序中,这种控制流出现在根据超参数对模型架构进行决策的代码中。作为一个具体的例子:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # 这个if语句被称为静态控制流。
        # 它的条件不依赖于任何输入值
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)

traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    return linear_1
"""

traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    relu_1 = torch.relu(linear_1);  linear_1 = None
    return relu_1
"""

if-statement if self.do_activation 不依赖于任何函数输入,因此它是静态的。do_activation 可以被视为一个超参数,并且具有该参数不同值的 MyModule 的不同实例具有不同的代码。这是一种有效的模式,由符号追踪支持。

许多动态控制流的实例在语义上是静态控制流。这些实例可以通过移除对输入值的数据依赖来支持符号追踪,例如通过将值移动到Module属性中,或在符号追踪期间将具体值绑定到参数上:

def f(x, flag):
    if flag: return x
    else: return x*2

fx.symbolic_trace(f) # 失败!

fx.symbolic_trace(f, concrete_args={'flag': True})

在真正动态控制流的情况下,包含此代码的程序部分可以被跟踪为对方法(参见使用Tracer类自定义跟踪)或函数(参见wrap())的调用,而不是通过它们进行跟踪。

torch函数

FX 使用 __torch_function__ 作为拦截调用的机制(有关此内容的更多信息,请参阅 技术概述)。一些函数,例如内置的 Python 函数或 math 模块中的函数,不受 __torch_function__ 的覆盖,但我们仍然希望在符号追踪中捕获它们。例如:

import torch
import torch.fx
from math import sqrt

def normalize(x):
    """
    通过批次维度的大小对 `x` 进行归一化
    """
    return x / sqrt(len(x))

# 这是有效的 Python 代码
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  文件 "sqrt.py",第 9 行,在 normalize 中
    return x / sqrt(len(x))
  文件 "pytorch/torch/fx/proxy.py",第 161 行,在 __len__ 中
    引发 RuntimeError("'len' 在符号追踪中默认不支持。如果你想记录这个调用,请在模块作用域中调用 torch.fx.wrap('len')")
RuntimeError: 'len' 在符号追踪中默认不支持。如果你想记录这个调用,请在模块作用域中调用 torch.fx.wrap('len')
"""

错误告诉我们内置函数 len 不受支持。 我们可以使这样的函数在跟踪中作为直接调用记录,使用 wrap() API:

torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

使用Tracer类自定义跟踪

The Tracer 类是实现 symbolic_trace 的基础类。可以通过子类化 Tracer 来自定义跟踪行为,如下所示:

class MyCustomTracer(torch.fx.Tracer):
    # 在这里你可以覆盖各种方法
    # 以自定义追踪。请参阅 `Tracer` API
    # 参考
    pass


# 让我们使用这个自定义追踪器来追踪这个模块
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

traced_graph = MyCustomTracer().trace(mod)
# trace() 返回一个 Graph。让我们将其包装在一个
# GraphModule 中以使其可运行
traced = torch.fx.GraphModule(mod, traced_graph)

叶模块

Leaf Modules 是那些在符号跟踪中作为调用出现而不是被跟踪的模块。默认的 Leaf Modules 集合是标准的 torch.nn 模块实例集合。例如:

class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` 被保留为一个调用,然而 `submod` 被追踪了。
# 这是因为默认的“叶模块”集合包括了所有
# 标准的 `torch.nn` 模块。
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    neg_1 = torch.neg(linear_1);  linear_1 = None
    return neg_1
"""

可以通过重写Tracer.is_leaf_module()来自定义叶子模块的集合。

杂项

  • 张量构造函数(例如 torch.zeros, torch.ones, torch.rand, torch.randn, torch.sparse_coo_tensor) 目前不可追踪。

    • 确定性构造函数(zerosones)可以使用,并且它们生成的值将作为常量嵌入到跟踪中。这只有在这些构造函数的参数引用动态输入大小时才会出现问题。在这种情况下,ones_likezeros_like 可能是可行的替代方案。

    • 非确定性构造函数(randrandn)将在跟踪中嵌入一个随机值。这可能不是预期的行为。一种解决方法是将torch.randn包装在一个torch.fx.wrap函数中,并调用该函数。

    @torch.fx.wrap
    def torch_randn(x, shape):
        return torch.randn(shape)
    
    def f(x):
        return x + torch_randn(x, 5)
    fx.symbolic_trace(f)
    
    • 此行为可能会在未来的版本中修复。

  • 类型注解

    • Python 3 风格的类型注解(例如 func(x : torch.Tensor, y : int) -> torch.Tensor)是支持的 并且将由符号追踪保留。

    • Python 2风格的注释类型注解 # 类型: (torch.Tensor, int) -> torch.Tensor 目前不支持。

    • 目前不支持在函数内对本地名称进行注解。

  • 训练标志和子模块周围抓住了

    • 当使用像 torch.nn.functional.dropout 这样的函数时,通常会将训练参数作为 self.training 传递进来。在 FX 追踪期间,这可能会被作为常量值固定下来。

    import torch
    import torch.fx
    
    class DropoutRepro(torch.nn.Module):
      def forward(self, x):
        return torch.nn.functional.dropout(x, training=self.training)
    
    
    traced = torch.fx.symbolic_trace(DropoutRepro())
    print(traced.code)
    """
    def forward(self, x):
      dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = None
      return dropout
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    """
    AssertionError: 张量不接近!
    
    不匹配的元素: 15 / 15 (100.0%)
    最大绝对差异: 1.6207983493804932 在索引 (0, 2) 处 (允许的最大差异为 1e-05)
    最大相对差异: 1.0 在索引 (0, 0) 处 (允许的最大差异为 0.0001)
    """
    
    • 然而,当使用标准的 nn.Dropout() 子模块时,训练标志被封装,并且由于 nn.Module 对象模型的保留,可以进行更改。

    class DropoutRepro2(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.drop = torch.nn.Dropout()
    
      def forward(self, x):
        return self.drop(x)
    
    traced = torch.fx.symbolic_trace(DropoutRepro2())
    print(traced.code)
    """
    def forward(self, x):
      drop = self.drop(x);  x = None
      return drop
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    
  • 由于这种差异,请考虑将那些与training标志动态交互的模块标记为叶子模块。

API参考

torch.fx.symbolic_trace(root, concrete_args=None)[源代码]

符号追踪 API

给定一个 nn.Module 或函数实例 root,此函数将返回一个通过记录在追踪 root 过程中看到的操作而构建的 GraphModule

concrete_args 允许你部分特化你的函数,无论是为了移除控制流还是数据结构。

例如:

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

由于控制流的存在的,FX通常无法跟踪此过程。然而,我们可以使用concrete_args来专门处理b的值以进行跟踪:

f = fx.symbolic_trace(f, concrete_args={'b': False})
assert f(3, False)  == 6

请注意,尽管您仍然可以传入不同的b值,但它们将被忽略。

我们也可以使用concrete_args来消除函数中的数据结构处理。这将使用pytrees来展平您的输入。为了避免过度专门化,对于不应该专门化的值,请传入fx.PH。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
Parameters
  • (联合[torch.nn.Module, 可调用]) – 要追踪并转换为图表示的模块或函数。

  • concrete_args (可选[字典[str, 任意]]) – 要部分特化的输入

Returns

从记录的操作创建的模块来自 root

Return type

GraphModule

注意

此API的向后兼容性得到保证。

torch.fx.wrap(fn_or_name)[源代码]

此函数可以在模块级范围内调用,以将 fn_or_name 注册为“叶函数”。 “叶函数”将在 FX 跟踪中保留为 CallFunction 节点,而不是被跟踪通过:

# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y

torch.fx.wrap('my_custom_function')

def fn_to_be_traced(x, y):
    # 在符号追踪时,对 my_custom_function 的调用将被插入到
    # 图中,而不是对其进行追踪。
    return my_custom_function(x, y)

这个函数也可以等价地用作装饰器:

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y

一个包装函数可以被视为一个“叶子函数”,类似于“叶子模块”的概念,也就是说,它们是在FX跟踪中保留为调用的函数,而不是被跟踪通过的函数。

Parameters

fn_or_name (Union[str, Callable]) – 当调用时插入到图中的函数或全局函数的名称

注意

此API的向后兼容性得到保证。

class torch.fx.GraphModule(*args, **kwargs)[源代码]

GraphModule 是一个由 fx.Graph 生成的 nn.Module。GraphModule 具有一个 graph 属性,以及由该 graph 生成的 codeforward 属性。

警告

graph 被重新赋值时,codeforward 将自动重新生成。然而,如果你编辑了 graph 的内容而没有重新赋值 graph 属性本身,你必须调用 recompile() 来更新生成的代码。

注意

此API的向后兼容性得到保证。

__init__(root, graph, class_name='GraphModule')[源代码]

构建一个GraphModule。

Parameters
  • (联合[torch.nn.Module, 字典[字符串, 任意]) – 可以是 nn.Module 实例或映射字符串到任意属性类型的字典。 如果 是一个模块,图的节点中任何对基于模块的对象的引用(通过限定名称)将在 的模块层次结构中从相应位置复制到 GraphModule 的模块层次结构中。 如果 是一个字典,节点中的限定名称将直接在字典的键中查找。字典映射到的对象将被复制到 GraphModule 的模块层次结构中的适当位置。

  • (Graph) – graph 包含此GraphModule应使用的节点以进行代码生成

  • 类名 (str) – name 表示此 GraphModule 的名称,用于调试目的。如果未设置,所有错误消息将报告为源自 GraphModule。设置此项可能有助于将其设置为 root 的原始名称或在您的转换上下文中具有意义的名称。

注意

此API的向后兼容性得到保证。

add_submodule(target, m)[源代码]

将给定的子模块添加到self

如果它们是 target 的子路径,则在尚不存在的情况下安装空模块。

Parameters
  • 目标 (str) – 新子模块的完全限定字符串名称 (参见 nn.Module.get_submodule 中的示例,了解如何指定完全限定字符串。)

  • m (模块) – 子模块本身;我们想要安装在当前模块中的实际对象

Returns

子模块是否可以被插入。对于

此方法返回True,链中的每个对象 由target表示必须要么a) 还不存在, 或者b) 引用一个nn.Module(不是参数或其他属性)

Return type

bool

注意

此API的向后兼容性得到保证。

property code: str

返回从该GraphModule的底层Graph生成的Python代码。

delete_all_unused_submodules()[源代码]

删除self中所有未使用的子模块。

一个模块被认为是“已使用”的,如果以下任一条件为真: 1. 它有被使用的子模块 2. 它的forward通过call_module节点直接调用 3. 它有一个非Module属性,该属性通过get_attr节点被使用

此方法可以被调用来清理一个 nn.Module,而不需要手动调用 delete_submodule 在每个未使用的子模块上。

注意

此API的向后兼容性得到保证。

delete_submodule(target)[源代码]

self 中删除给定的子模块。

如果target不是一个有效的目标,模块将不会被删除。

Parameters

目标 (str) – 新子模块的完全限定字符串名称 (参见 nn.Module.get_submodule 中的示例,了解如何指定完全限定字符串。)

Returns

目标字符串是否引用了

我们想要删除的子模块。返回值为 False 表示 target 不是对子模块的有效引用。

Return type

bool

注意

此API的向后兼容性得到保证。

property graph:

返回此 GraphModule 的基础 Graph

print_readable(print_output=True)[源代码]

返回为当前GraphModule及其子GraphModules生成的Python代码

警告

此API是实验性的,并且向后兼容。

recompile()[源代码]

从其 graph 属性重新编译此 GraphModule。在编辑包含的 graph 后应调用此方法,否则此 GraphModule 生成的代码将过时。

注意

此API的向后兼容性得到保证。

Return type

Python代码

to_folder(folder, module_name='FxModule')[源代码]
Dumps out module to folder with module_name so that it can be

导入自 from import

参数:

folder (Union[str, os.PathLike]): 要写入代码的文件夹

module_name (str): Top-level name to use for the Module while

写出代码

警告

此API是实验性的,并且向后兼容。

class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[源代码]

Graph 是 FX Intermediate Representation 中使用的主要数据结构。 它由一系列 Node 组成,每个节点代表调用点(或其他语法结构)。这些 Node 的列表共同构成一个有效的 Python 函数。

例如,以下代码

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

将生成以下图形:

print(gm.graph)
graph(x):
    %linear_weight : [num_users=1] = self.linear.weight
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

关于在 Graph 中表示的操作的语义,请参见 Node

注意

此API的向后兼容性得到保证。

__init__(owning_module=None, tracer_cls=None, tracer_extras=None)[源代码]

构造一个空的图。

注意

此API的向后兼容性得到保证。

call_function(the_function, args=None, kwargs=None, type_expr=None)[源代码]

插入一个 call_function 节点 中。一个 call_function 节点 表示对一个 Python 可调用对象的调用,由 the_function 指定。

Parameters
  • the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 操作符、Python 函数,或者是 builtinsoperator 命名空间中的成员。

  • args (可选[元组[参数, ...]]) – 要传递给被调用函数的定位参数。

  • kwargs (可选[字典[str, 参数]]) – 传递给被调用函数的键值参数

  • type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。

Returns

新创建并插入的 call_function 节点。

Return type

节点

注意

此方法的插入点和类型表达式规则与Graph.create_node()相同。

注意

此API的向后兼容性得到保证。

call_method(method_name, args=None, kwargs=None, type_expr=None)[源代码]

Graph中插入一个call_method Node。一个call_method节点 表示对args的第0个元素上的给定方法的调用。

Parameters
  • method_name (str) – 要应用于self参数的方法名称。 例如,如果args[0]是一个表示TensorNode, 那么要在该Tensor上调用relu(),请将relu传递给method_name

  • args (可选[元组[参数, ...]]) – 要传递给被调用方法的位置参数。请注意,这应该包括一个self参数。

  • kwargs (可选[字典[str, Argument]]) – 要传递给被调用方法的关键字参数

  • type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。

Returns

新创建并插入的 call_method 节点。

Return type

节点

注意

此方法的插入点和类型表达式规则与Graph.create_node()相同。

注意

此API的向后兼容性得到保证。

call_module(module_name, args=None, kwargs=None, type_expr=None)[源代码]

Graph中插入一个call_module Node。一个call_module节点 表示对Module层次结构中某个Module的forward()函数的调用。

Parameters
  • module_name (str) – 在 Module 层次结构中要调用的 Module 的限定名称。例如,如果被跟踪的 Module 有一个名为 foo 的子模块,该子模块又有一个名为 bar 的子模块,则应将限定名称 foo.bar 作为 module_name 传递以调用该模块。

  • args (可选[元组[参数, ...]]) – 要传递给被调用方法的位置参数。请注意,这应包括self参数。

  • kwargs (可选[字典[str, Argument]]) – 要传递给被调用方法的关键字参数

  • type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。

Returns

新创建并插入的 call_module 节点。

Return type

节点

注意

此方法的插入点和类型表达式规则与Graph.create_node()相同。

注意

此API的向后兼容性得到保证。

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[源代码]

创建一个节点并将其添加到当前插入点的中。 请注意,当前插入点可以通过Graph.inserting_before()Graph.inserting_after()来设置。

Parameters
  • op (str) – 此节点的操作码。可以是 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’ 之一。这些操作码的语义在 Graph 文档字符串中描述。

  • args (可选[元组[参数, ...]]) – 是传递给此节点的参数元组。

  • kwargs (可选[字典[str, Argument]]) – 此节点的kwargs

  • 名称 (可选[字符串]) – 为节点指定的一个可选字符串名称。 这将影响在生成的Python代码中分配的值的名称。

  • type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。

Returns

新创建并插入的节点。

Return type

节点

注意

此API的向后兼容性得到保证。

eliminate_dead_code()[源代码]

根据每个节点的用户数量以及节点是否具有副作用,从图中移除所有无效代码。在调用之前,图必须进行拓扑排序。

Returns

图表是否由于传递而发生了变化。

Return type

bool

示例:

在死代码被消除之前,a 来自 a = x + 1 下面没有使用者,因此可以从图中消除而不会产生影响。

def forward(self, x):
    a = x + 1
    return x + self.attr_1

在死代码被消除后,a = x + 1 已被移除,而其余的 forward 仍然保留。

def forward(self, x):
    return x + self.attr_1

警告

死代码消除有一些启发式方法来避免移除具有副作用的节点(参见 Node.is_impure),但总体覆盖率非常差,因此除非你确定你的 FX 图完全由函数操作组成,否则你应该假设调用此方法是不可靠的。

注意

此API的向后兼容性得到保证。

erase_node(to_erase)[源代码]

Graph中删除一个Node。如果在Graph中仍有该节点的用户,则会抛出异常。

Parameters

to_erase (节点) – 要从 Graph 中删除的 Node

注意

此API的向后兼容性得到保证。

get_attr(qualified_name, type_expr=None)[源代码]

在图中插入一个get_attr节点。一个get_attr Node表示从Module层次结构中获取一个属性。

Parameters
  • qualified_name (str) – 要检索的属性的完全限定名称。 例如,如果被跟踪的模块有一个名为 foo 的子模块,该子模块有一个名为 bar 的子模块,该子模块有一个名为 baz 的属性,则应将完全限定名称 foo.bar.baz 作为 qualified_name 传递。

  • type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。

Returns

新创建并插入的 get_attr 节点。

Return type

节点

注意

此方法的插入点和类型表达式规则与Graph.create_node相同。

注意

此API的向后兼容性得到保证。

graph_copy(g, val_map, return_output_node=False)[源代码]

将给定图中的所有节点复制到 self 中。

Parameters
  • g () – 要从中复制节点的源图。

  • val_map (Dict[Node, Node]) – 一个将被填充的映射字典,从 g 中的节点映射到 self 中的节点。注意,val_map 可以预先传入值以覆盖某些值的复制。

Returns

如果 g 有一个 output 节点,则 self 中的值现在等同于 g 中的输出值。否则为 None

Return type

可选[联合[元组[任意, …], 列表[任意], 字典[字符串, 任意], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载]]

注意

此API的向后兼容性得到保证。

inserting_after(n=None)[源代码]
Set the point at which create_node and companion methods will insert into the graph.

当在‘with’语句中使用时,这将临时设置插入点,并在with语句退出时恢复它:

with g.inserting_after(n):
    ... # 在节点n之后插入
... # 插入点恢复到之前的状态
g.inserting_after(n) # 永久设置插入点

参数:

n (Optional[Node]): The node before which to insert. If None this will insert after

整个图的开始。

Returns:

一个资源管理器,将在 __exit__ 时恢复插入点。

注意

此API的向后兼容性得到保证。

inserting_before(n=None)[源代码]
Set the point at which create_node and companion methods will insert into the graph.

当在‘with’语句中使用时,这将临时设置插入点,并在with语句退出时恢复它:

with g.inserting_before(n):
    ... # 在节点 n 之前插入
... # 插入点恢复到之前的状态
g.inserting_before(n) # 永久设置插入点

参数:

n (Optional[Node]): The node before which to insert. If None this will insert before

整个图的开始。

Returns:

一个资源管理器,将在 __exit__ 时恢复插入点。

注意

此API的向后兼容性得到保证。

lint()[源代码]

对当前图进行各种检查,以确保其结构良好。特别是: - 检查节点是否具有正确的所有权(归属于此图) - 检查节点是否按拓扑顺序排列 - 如果此图归属于一个GraphModule,检查目标是否存在于该GraphModule中

注意

此API的向后兼容性得到保证。

node_copy(node, arg_transform=<function Graph.<lambda>>)[源代码]

将一个节点从一个图复制到另一个图中。arg_transform 需要将参数从节点所在的图转换到当前图。示例:

# 将 `g` 中的所有节点复制到 `new_graph` 中
g : torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
Parameters
  • 节点 (节点) – 要复制到 self 的节点。

  • arg_transform (Callable[[Node], Argument]) – 一个函数,用于将节点中的Node参数在argskwargs中转换为self中的等效参数。在最简单的情况下,这应该从原始图中将节点映射到self的表中检索值。

Return type

节点

注意

此API的向后兼容性得到保证。

property nodes: _node_list

获取构成此图的节点列表。

请注意,这个节点列表表示是一个双向链表。在迭代期间进行变异(例如删除一个节点,添加一个节点)是安全的。

Returns

一个双向链表的节点。请注意,可以在该列表上调用 reversed 来切换迭代顺序。

on_generate_code(make_transformer)[源代码]

在生成Python代码时注册一个转换器函数

Args:
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):

一个返回代码转换器的函数,该函数由on_generate_code调用以获取代码转换器。

此函数还会将其输入作为当前注册的代码转换器(如果没有注册则为None),以防不希望覆盖它。这对于将代码转换器链接在一起非常有用。

Returns:

一个上下文管理器,当在with语句中使用时,自动恢复之前注册的代码转换器。

示例:

gm: fx.GraphModule = ...

# 这是一个我们想要注册的代码转换器。这个代码转换器在生成的torch.fx代码的最前面插入一个pdb导入和跟踪语句,以便使用PDB库进行手动调试。
def insert_pdb(body):
    return ["import pdb; pdb.set_trace()\n", *body]

# 注册`insert_pdb`,并覆盖当前注册的代码转换器(由lambda中的`_`给出):
gm.graph.on_generate_code(
    lambda _: insert_pdb
)

# 或者,注册一个代码转换器,该转换器首先通过现有的注册转换器运行`body`,然后通过`insert_pdb`运行:
gm.graph.on_generate_code(
    lambda current_trans: (
        lambda body: insert_pdb(
            current_trans(body) if current_trans
            else body
        )
    )
)

gm.recompile()
gm(*inputs)  # 进入pdb

此函数也可以用作上下文管理器,其好处是自动恢复之前注册的代码转换器:

# ... 继续自上一个示例

with gm.graph.on_generate_code(lambda _: insert_pdb):
    # 对 `gm` 进行更多操作...
    gm.recompile()
    gm(*inputs)  # 进入 pdb

# 现在恢复了之前的代码转换器(但带有 pdb 的 `gm` 代码仍然存在 - 这意味着你可以在这里使用 pdb 运行 `gm`,直到你
# 运行下一个 `recompile()`)。

警告

此API是实验性的,并且向后兼容。

output(result, type_expr=None)[源代码]

插入一个output NodeGraph中。一个output节点代表Python代码中的一个return语句。result是应该返回的值。

Parameters
  • 结果 (参数) – 要返回的值。

  • type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。

注意

此方法的插入点和类型表达式规则与Graph.create_node相同。

注意

此API的向后兼容性得到保证。

placeholder(name, type_expr=None, default_value)[源代码]

在图中插入一个占位符节点。一个占位符表示一个函数输入。

Parameters
  • 名称 (字符串) – 输入值的名称。这对应于此Graph表示的函数的定位参数的名称。

  • type_expr (可选[任意]) – 一个可选的类型注解,表示此节点的输出将具有的Python类型。在某些情况下,这是正确代码生成所必需的(例如,当函数随后用于TorchScript编译时)。

  • default_value (任意) – 此函数参数应采用的默认值。注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 作为此参数传递,以指定该参数没有默认值。

Return type

节点

注意

此方法的插入点和类型表达式规则与Graph.create_node相同。

注意

此API的向后兼容性得到保证。

print_tabular()[源代码]

以表格格式打印图的中间表示。请注意,此API需要安装tabulate模块。

注意

此API的向后兼容性得到保证。

process_inputs(*args)[源代码]

处理参数以便它们可以传递给FX图。

警告

此API是实验性的,并且向后兼容。

process_outputs(out)[源代码]

警告

此API是实验性的,并且向后兼容。

python_code(root_module, *, verbose=False)[源代码]

将这个 Graph 转换为有效的 Python 代码。

Parameters

root_module (str) – 要查找限定名称目标的根模块名称。这通常是‘self’。

Returns

src: 表示对象的Python源代码 globals: 全局名称的字典,在src中 -> 它们引用的对象。

Return type

一个PythonCode对象,由两个字段组成

注意

此API的向后兼容性得到保证。

set_codegen(codegen)[源代码]

警告

此API是实验性的,并且向后兼容。

class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[源代码]

Node 是表示 Graph 中各个操作的数据结构。在大多数情况下,节点表示对各种实体的调用站点,例如操作符、方法和模块(一些例外包括指定函数输入和输出的节点)。每个 Node 都有一个由其 op 属性指定的函数。每个 op 值的 Node 语义如下:

  • placeholder 表示一个函数输入。name 属性指定此值将采用的名称。 target 同样表示参数的名称。args 包含以下之一:1) 无,或 2) 表示函数输入默认参数的单个参数。kwargs 不关心。占位符对应于图输出中的函数参数(例如 x)。

  • get_attr 从模块层次结构中检索一个参数。name 是同样地,获取结果被分配到的名称。target 是参数在模块层次结构中的完全限定名称。 argskwargs 是无关紧要的

  • call_function 将一个自由函数应用于某些值。name 同样是要赋值的值的名称。target 是要应用的函数。argskwargs 表示函数的参数,遵循 Python 调用约定。

  • call_module 在模块层次结构的 forward() 方法中应用一个模块到给定的参数。name 与之前相同。target 是模块层次结构中要调用的模块的完全限定名称。argskwargs 表示调用模块的参数,不包括 self 参数

  • call_method 调用一个值上的方法。name 类似。target 是要应用于 self 参数的方法的字符串名称。argskwargs 表示要调用模块的参数,包括 self 参数

  • output 包含 traced 函数在其 args[0] 属性中的输出。这对应于 Graph 打印输出中的“return”语句。

注意

此API的向后兼容性得到保证。

property all_input_nodes: List[节点]

返回所有作为此节点输入的节点。这相当于迭代 argskwargs,并且只收集那些是节点的值。

Returns

在此 Nodeargskwargs 中出现的 Nodes 列表,按该顺序排列。

append(x)[源代码]

在图的节点列表中,在此节点后插入 x。 等同于 self.next.prepend(x)

Parameters

x (节点) – 要放在此节点之后的节点。必须是同一图的成员。

注意

此API的向后兼容性得到保证。

property args: Tuple[Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, 节点, str, int, float, bool, complex, 数据类型, 张量, 设备, 内存格式, 布局, OpOverload]], ...]

这个 Node 的参数元组。参数的解释取决于节点的操作码。更多信息请参阅 Node 文档字符串。

允许对此属性进行赋值。在赋值时,所有使用和用户的记录都会自动更新。

format_node(placeholder_names=None, maybe_return_typename=None)[源码]

返回一个描述性的字符串表示形式。

此方法可以不带参数使用,作为调试工具。

此函数也在 Graph__str__ 方法中内部使用。placeholder_namesmaybe_return_typename 中的字符串共同构成了此图所包围的 GraphModule 中自动生成的 forward 函数的签名。placeholder_namesmaybe_return_typename 不应在其他情况下使用。

Parameters
  • placeholder_names (可选[列表[字符串]]) – 一个将存储格式化字符串的列表,表示生成的forward函数中的占位符。仅限内部使用。

  • maybe_return_typename (可选[列表[字符串]]) – 一个单元素列表,将存储 表示生成的forward函数输出的格式化字符串。仅限内部使用。

Returns

如果1) 我们使用 format_node 作为内部辅助函数

Graph__str__ 方法中,并且 2) self 是一个占位符节点,返回 None。否则,返回当前节点的描述性字符串表示。

Return type

str

注意

此API的向后兼容性得到保证。

insert_arg(idx, arg)[源代码]

在参数列表中插入一个带有指定索引的位置参数。

Parameters
  • idx (int) – 在 self.args 中要插入元素的索引。

  • arg (参数) – 要插入到 args 中的新参数值

注意

此API的向后兼容性得到保证。

is_impure()[源代码]

返回此操作是否为不纯操作,即如果其操作是占位符或输出,或者如果调用的是不纯的call_function或call_module。

Returns

如果操作是纯操作或非纯操作。

Return type

bool

警告

此API是实验性的,并且向后兼容。

property kwargs: Dict[str, Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, 节点, str, int, float, bool, complex, 数据类型, 张量, 设备, 内存格式, 布局, OpOverload]]]

传递给此节点的关键字参数字典。参数的解释取决于节点的操作码。更多信息请参阅节点文档字符串。

允许对此属性进行赋值。在赋值时,所有使用和用户的记录都会自动更新。

property next: 节点

返回链表中下一个 Node

Returns

链表中下一个Node节点。

normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[源代码]

返回标准化后的参数给Python目标。这意味着 args/kwargs 将被匹配到模块/函数的签名,并且如果 normalize_to_only_use_kwargs 为真,则按位置顺序返回仅包含kwargs的参数。 同时填充默认值。不支持仅位置参数或可变参数。

支持模块调用。

可能需要 arg_typeskwarg_types 以消除重载的歧义。

Parameters
  • (torch.nn.Module) – 要解析模块目标的模块。

  • arg_types (可选[元组[任意]]) – 参数的参数类型元组

  • kwarg_types (可选[字典[str, 任意]]) – 关键字参数的类型字典

  • normalize_to_only_use_kwargs (布尔值) – 是否规范化为仅使用关键字参数。

Returns

返回 NamedTuple ArgsKwargsPair,如果未成功则返回 None

Return type

可选[ArgsKwargsPair]

警告

此API是实验性的,并且向后兼容。

prepend(x)[源代码]

在图的节点列表中,在此节点之前插入 x。示例:

之前: p -> self
        bx -> x -> ax
之后:  p -> x -> self
        bx -> ax
Parameters

x (节点) – 要放在此节点之前的节点。必须是同一图的成员。

注意

此API的向后兼容性得到保证。

property prev: 节点

返回链表中前一个节点

Returns

链表中前一个节点

replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[源代码]

将Graph中所有使用的self替换为Node replace_with

Parameters
  • replace_with (节点) – 用于替换所有self使用的节点。

  • delete_user_cb (可调用对象) – 用于确定是否应删除自节点中给定用户的回调函数。

  • propagate_meta (bool) – 是否将原始节点上的.meta字段的所有属性复制到替换节点上。为了安全起见,仅当替换节点尚未存在.meta字段时,此操作才有效。

Returns

此更改所涉及的节点列表。

Return type

列表[节点]

注意

此API的向后兼容性得到保证。

replace_input_with(old_input, new_input)[源代码]

遍历 self 的输入节点,并将所有 old_input 的实例替换为 new_input

Parameters
  • old_input (节点) – 要替换的旧输入节点。

  • new_input (节点) – 用于替换 old_input 的新输入节点。

注意

此API的向后兼容性得到保证。

property stack_trace: Optional[str]

返回在追踪过程中记录的Python堆栈跟踪,如果有的话。 当使用fx.Tracer进行追踪时,此属性通常由Tracer.create_proxy填充。为了在调试目的下记录追踪期间的堆栈跟踪,请在Tracer实例上设置record_stack_traces = True。 当使用dynamo进行追踪时,此属性将默认由OutputGraph.create_proxy填充。

stack_trace 将在字符串的末尾包含最内层的帧。

update_arg(idx, arg)[源代码]

更新现有的位置参数以包含新值 arg。调用后,self.args[idx] == arg

Parameters
  • idx (int) – 要更新的元素在 self.args 中的索引

  • arg (参数) – 要写入 args 的新参数值

注意

此API的向后兼容性得到保证。

update_kwarg(key, arg)[源代码]

更新现有的关键字参数以包含新值 arg。调用后,self.kwargs[key] == arg

Parameters
  • (字符串) – 要更新的元素在 self.kwargs 中的键

  • arg (参数) – 要写入 kwargs 的新参数值

注意

此API的向后兼容性得到保证。

class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[源代码]

Tracer 是实现 torch.fx.symbolic_trace 符号追踪功能的类。调用 symbolic_trace(m) 等同于调用 Tracer().trace(m)

Tracer 可以被继承以覆盖跟踪过程中各种行为。可以覆盖的不同行为在该类的各个方法的文档字符串中进行了描述。

注意

此API的向后兼容性得到保证。

call_module(m, forward, args, kwargs)[源代码]

指定此Tracer在遇到对nn.Module实例的调用时的行为的方法。

默认情况下,行为是检查被调用的模块是否为叶模块,通过 is_leaf_module。如果是,则发出一个引用 mcall_module 节点在 Graph 中。否则,正常调用 Module,跟踪其 forward 函数中的操作。

此方法可以被重写以——例如——创建嵌套的跟踪GraphModules,或您希望在跨越Module边界时所需的任何其他行为。

Parameters
  • m (模块) – 正在发出调用的模块

  • forward (可调用对象) – 要调用的 Module 的 forward() 方法

  • args (元组) – 模块调用点的参数

  • kwargs (字典) – 模块调用点的kwargs

Returns

Module 调用的返回值。如果发出了一个 call_module 节点,则这是一个 Proxy 值。否则,它是从 Module 调用返回的任何值。

Return type

任意

注意

此API的向后兼容性得到保证。

create_arg(a)[源代码]

指定在准备值以用作 Graph 中节点的参数时,跟踪行为的指定方法。

默认情况下,行为包括:

  1. 遍历集合类型(例如元组、列表、字典)并对元素递归调用 create_args

  2. 给定一个代理对象,返回对底层 IR Node 的引用

  3. 给定一个非代理张量对象,为各种情况生成IR:

    • 对于一个参数,生成一个引用该参数的 get_attr 节点

    • 对于非参数张量,将张量存储在一个特殊属性中,该属性引用该属性。

此方法可以被重写以支持更多类型。

Parameters

a (任意) – 作为 ArgumentGraph 中发出的值。

Returns

将值 a 转换为适当的 Argument

Return type

可选[联合[元组[任意, …], 列表[任意], 字典[字符串, 任意], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载]]

注意

此API的向后兼容性得到保证。

create_args_for_root(root_fn, is_module, concrete_args=None)[源代码]

创建与root模块签名相对应的placeholder节点。此方法内省root的签名并相应地发出这些节点,同时支持*args**kwargs

警告

此API是实验性的,并且向后兼容。

create_node(kind, target, args, kwargs, name=None, type_expr=None)

插入一个图节点,给定目标、参数、关键字参数和名称。

此方法可以被重写以进行额外的检查、验证或修改用于节点创建的值。例如,可能希望禁止记录就地操作。

注意

此API的向后兼容性得到保证。

Return type

节点

create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)

从给定的参数创建一个节点,然后返回包装在代理对象中的节点。

如果 kind = ‘placeholder’,那么我们正在创建一个表示函数参数的节点。如果我们需要编码一个默认参数,我们使用 args 元组。对于 placeholder 节点,args 通常是空的。

注意

此API的向后兼容性得到保证。

getattr(attr, attr_val, parameter_proxy_cache)[源代码]

指定在调用 nn.Module 实例时,当我们对 Tracer 调用 getattr 时的行为的函数。

默认情况下,行为是返回属性的代理值。它还会将代理值存储在parameter_proxy_cache中,以便将来的调用将重用该代理,而不是创建一个新的代理。

此方法可以被重写以——例如——在查询参数时不返回代理。

Parameters
  • attr (str) – 查询的属性名称

  • attr_val (任意) – 属性的值

  • parameter_proxy_cache (字典[str, 任意]) – attr 名称到代理的缓存

Returns

getattr 调用的返回值。

警告

此API是实验性的,并且向后兼容。

is_leaf_module(m, module_qualified_name)[源代码]

指定给定的 nn.Module 是否为“叶子”模块的方法。

叶子模块是出现在IR中的原子单元,通过call_module调用引用。默认情况下,PyTorch标准库命名空间(torch.nn)中的模块是叶子模块。所有其他模块都会被追踪,并且它们的组成操作会被记录,除非通过此参数另行指定。

Parameters
  • m (模块) – 被查询的模块

  • module_qualified_name (str) – 该模块到根的路径。例如, 如果你有一个模块层次结构,其中子模块 foo 包含 子模块 bar,其中包含子模块 baz,该模块将 在此处显示为限定名称 foo.bar.baz

Return type

bool

注意

此API的向后兼容性得到保证。

iter(obj)
Called when a proxy object is being iterated over, such as

当用于控制流程时。通常我们不知道该做什么,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个迭代器。

注意

此API的向后兼容性得到保证。

Return type

迭代器

keys(obj)
Called when a proxy object is has the keys() method called.

这是在代理上调用 ** 时发生的情况。如果 ** 应该在您的自定义跟踪器中工作,则应返回一个迭代器。

注意

此API的向后兼容性得到保证。

Return type

任意

path_of_module(mod)[源代码]

用于在 root 的模块层次结构中查找 mod 的限定名称的辅助方法。例如,如果 root 有一个名为 foo 的子模块,该子模块又有一个名为 bar 的子模块,将 bar 传递给此函数将返回字符串 “foo.bar”。

Parameters

mod (str) – 要检索限定名称的模块

Return type

str

注意

此API的向后兼容性得到保证。

proxy(node)

注意

此API的向后兼容性得到保证。

Return type

代理

to_bool(obj)
Called when a proxy object is being converted to a boolean, such as

当用于控制流程时。通常我们不知道该做什么,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个值。

注意

此API的向后兼容性得到保证。

Return type

bool

trace(root, concrete_args=None)[源代码]

跟踪 root 并返回相应的 FX Graph 表示。root 可以是 nn.Module 实例或 Python 可调用对象。

请注意,在此调用之后,self.root 可能与传入的 root 不同。例如,当一个自由函数传递给 trace() 时,我们将创建一个 nn.Module 实例作为根,并添加嵌入的常量。

Parameters
  • (联合[模块, 可调用]) – 可以是 模块 或要通过其追踪的函数。此参数的向后兼容性得到保证。

  • concrete_args (可选[字典[str, 任意]]) – 不应被视为代理的具体参数。此参数是实验性的,其向后兼容性保证。

Returns

一个表示传入的 root 语义的 Graph

Return type

注意

此API的向后兼容性得到保证。

class torch.fx.Proxy(node, tracer=None)[源代码]

Proxy 对象是 Node 包装器,它们在符号追踪过程中流经程序,并记录所有操作(torch 函数调用、方法调用、运算符)到不断增长的 FX 图中。

如果你正在进行图变换,你可以将你自己的 Proxy 方法包装在一个原始的 Node 周围,这样你就可以使用重载的运算符向 Graph 添加额外的东西。

Proxy 对象不能被迭代。换句话说,如果在一个循环中使用 Proxy 或在函数参数中使用 *args/**kwargs,符号追踪器将会抛出错误。

有两种主要的方法来解决这个问题: 1. 将不可追踪的逻辑提取到一个顶层函数中,并使用fx.wrap对其进行包装。 2. 如果控制流是静态的(即循环次数基于某些超参数),代码可以保留在原位并重构为如下形式:

for i in range(self.some_hyperparameter):
    indexed_item = proxied_value[i]

有关Proxy内部更详细的描述,请查看torch/fx/OVERVIEW.md中的“Proxy”部分

注意

此API的向后兼容性得到保证。

class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[源代码]

解释器逐节点执行FX图。这种模式可以用于许多方面,包括编写代码转换以及分析过程。

Interpreter类中的方法可以被重写以自定义执行行为。以下是可重写方法的调用层次结构图:

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

示例

假设我们想要将所有实例的 torch.neg 替换为 torch.sigmoid 并且反之亦然(包括它们的 Tensor 方法等效项)。我们可以像这样子类化 Interpreter:

```python class NegSigmSwapInterpreter(Interpreter): def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == 'neg': call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid()) ```
Parameters
  • 模块 (torch.nn.Module) – 要执行的模块

  • garbage_collect_values (bool) – 是否在模块执行过程中删除值在其最后一次使用后。这确保了执行期间的最佳内存使用。可以禁用此功能,例如,通过查看Interpreter.env属性来检查执行中的所有中间值。

  • graph (可选[Graph]) – 如果传递,解释器将执行此图而不是 module.graph,使用提供的 module 参数来满足任何状态请求。

注意

此API的向后兼容性得到保证。

boxed_run(args_list)[源代码]

通过解释运行模块并返回结果。 这使用了“装箱”调用约定,其中您传递一个参数列表,这些参数将由解释器清除。 这确保了输入张量被及时释放。

注意

此API的向后兼容性得到保证。

call_function(target, args, kwargs)[源代码]

执行一个call_function节点并返回结果。

Parameters
  • 目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 本次调用的关键字参数字典

Return type

任意

Return

任意:函数调用返回的值

注意

此API的向后兼容性得到保证。

call_method(target, args, kwargs)[源代码]

执行一个 call_method 节点并返回结果。

Parameters
  • 目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此次调用的关键字参数字典

Return type

任意

Return

任意:方法调用返回的值

注意

此API的向后兼容性得到保证。

call_module(target, args, kwargs)[源代码]

执行一个 call_module 节点并返回结果。

Parameters
  • 目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此次调用的关键字参数字典

Return type

任意

Return

任意: 模块调用返回的值

注意

此API的向后兼容性得到保证。

fetch_args_kwargs_from_env(n)[源代码]

从当前执行环境中获取节点 nargskwargs 的具体值。

Parameters

n (节点) – 要获取argskwargs的节点。

Returns

argskwargs 使用具体的值作为 n

Return type

元组[元组, 字典]

注意

此API的向后兼容性得到保证。

fetch_attr(target)[源代码]

self.moduleModule 层次结构中获取一个属性。

Parameters

目标 (str) – 要获取的属性的完全限定名称

Returns

属性的值。

Return type

任意

注意

此API的向后兼容性得到保证。

get_attr(target, args, kwargs)[源代码]

执行一个 get_attr 节点。将从 self.moduleModule 层次结构中检索属性值。

Parameters
  • 目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此次调用的关键字参数字典

Returns

被检索到的属性的值

Return type

任意

注意

此API的向后兼容性得到保证。

map_nodes_to_values(args, n)[源代码]

递归地遍历 args 并在当前执行环境中查找每个 Node 的具体值。

Parameters
  • args (参数) – 在其中查找具体值的数据结构

  • n (节点) – 参数 args 所属的节点。这仅用于错误报告。

Return type

可选[联合[元组[任意, …], 列表[任意], 字典[字符串, 任意], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载]]

注意

此API的向后兼容性得到保证。

output(target, args, kwargs)[源代码]

执行一个 output 节点。这实际上只是检索由 output 节点引用的值并返回它。

Parameters
  • 目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此次调用的关键字参数字典

Returns

输出节点引用的返回值

Return type

任意

注意

此API的向后兼容性得到保证。

placeholder(target, args, kwargs)[源代码]

执行一个placeholder节点。请注意,这是有状态的: Interpreter维护了一个内部迭代器,用于遍历传递给run的参数,并且此方法返回该迭代器的next()。

Parameters
  • 目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此次调用的关键字参数字典

Returns

获取到的参数值。

Return type

任意

注意

此API的向后兼容性得到保证。

run(*args, initial_env=None, enable_io_processing=True)[源代码]

通过解释运行模块并返回结果。

Parameters
  • *args – 要传递给模块的参数,按位置顺序排列

  • initial_env (可选[字典[Node, 任意值]]) – 一个可选的执行初始环境。 这是一个将Node映射到任意值的字典。例如,可以使用它来预填充某些Nodes的结果,以便在解释器中仅进行部分评估。

  • enable_io_processing (bool) – 如果为真,我们首先使用图的process_inputs和process_outputs函数处理输入和输出,然后再使用它们。

Returns

执行模块后返回的值

Return type

任意

注意

此API的向后兼容性得到保证。

run_node(n)[源代码]

运行特定节点 n 并返回结果。 调用占位符、获取属性、调用函数、 调用方法、调用模块或输出,具体取决于 node.op

Parameters

n (节点) – 要执行的节点

Returns

执行 n 的结果

Return type

任意

注意

此API的向后兼容性得到保证。

class torch.fx.Transformer(module)[源代码]

Transformer 是一种特殊类型的解释器,它生成一个新的 Module。它公开了一个 transform() 方法,该方法返回转换后的 ModuleTransformer 不需要像 Interpreter 那样运行参数。Transformer 完全以符号方式工作。

示例

假设我们想要将所有实例的 torch.neg 替换为 torch.sigmoid 以及反之亦然(包括它们的 Tensor 方法等价物)。我们可以像这样子类化 Transformer

class NegSigmSwapXformer(Transformer):
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)

transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
Parameters

模块 (GraphModule) – 要被转换的模块

注意

此API的向后兼容性得到保证。

call_function(target, args, kwargs)[源代码]

注意

此API的向后兼容性得到保证。

Return type

任意

call_module(target, args, kwargs)[源代码]

注意

此API的向后兼容性得到保证。

Return type

任意

get_attr(target, args, kwargs)[源代码]

执行一个 get_attr 节点。在 Transformer 中,这被重写以在输出图中插入一个新的 get_attr 节点。

Parameters
  • 目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此次调用的关键字参数字典

Return type

代理

注意

此API的向后兼容性得到保证。

placeholder(target, args, kwargs)[源代码]

执行一个占位符节点。在Transformer中,这被重写以在输出图中插入一个新的占位符

Parameters
  • 目标 (目标) – 此节点的调用目标。有关语义的详细信息,请参阅 Node

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此次调用的关键字参数字典

Return type

代理

注意

此API的向后兼容性得到保证。

transform()[源代码]

转换 self.module 并返回转换后的 GraphModule

注意

此API的向后兼容性得到保证。

Return type

图模块

torch.fx.replace_pattern(gm, pattern, replacement)[源代码]

匹配图模块(gm)图中所有可能的非重叠操作符及其数据依赖集(pattern),然后使用另一个子图(replacement)替换每个匹配的子图。

Parameters
Returns

表示在原始图中匹配到pattern的位置的Match对象列表。如果没有匹配项,则该列表为空。Match定义为:

class Match(NamedTuple):
    # 匹配找到的节点
    anchor: Node
    # 将模式子图中的节点映射到较大图中的节点
    nodes_map: Dict[Node, Node]

Return type

列表[匹配]

示例:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上述代码将首先在 traced_moduleforward 方法中匹配 pattern。模式匹配是基于使用-定义关系进行的,而不是基于节点名称。例如,如果你在 pattern 中有 p = torch.cat([a, b]),你可以匹配原始 forward 函数中的 m = torch.cat([a, b]),尽管变量名称不同(pm)。

pattern中的return语句仅根据其值进行匹配;它可能与较大图中的return语句匹配,也可能不匹配。换句话说,模式不必延伸到较大图的末尾。

当模式匹配时,它将从较大的函数中移除,并替换为replacement。如果在较大的函数中存在多个pattern的匹配项,每个不重叠的匹配项都将被替换。在匹配重叠的情况下,首先找到的重叠匹配项将被替换。(这里的“首先”是根据节点使用-定义关系的拓扑排序定义的。在大多数情况下,第一个节点是直接出现在self之后的参数,而最后一个节点是函数返回的内容。)

需要注意的一个重要事项是,pattern 可调用对象的参数必须在可调用对象本身中使用,并且 replacement 可调用对象的参数必须与模式匹配。第一个规则就是为什么在上面的代码块中,forward 函数有参数 x, w1, w2,但 pattern 函数只有参数 w1, w2pattern 不使用 x,所以它不应该将 x 指定为参数。作为第二个规则的示例,考虑替换

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

def replacement(x, y):
    return torch.relu(x)

在这种情况下,replacement 需要与 pattern 相同的参数数量(xy),即使参数 yreplacement 中未被使用。

在调用 subgraph_rewriter.replace_pattern 之后,生成的 Python 代码如下所示:

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2

注意

此API的向后兼容性得到保证。

优云智算