Shortcuts

torch.export IR 规范

Export IR 是编译器的一种中间表示(IR),与 MLIR 和 TorchScript 有相似之处。它专门设计用于表达 PyTorch 程序的语义。Export IR 主要通过简化的操作列表来表示计算,对动态性(如控制流)的支持有限。

要创建一个导出IR图,可以使用一个前端,通过跟踪专门化机制来可靠地捕获PyTorch程序。生成的导出IR可以由后端进行优化和执行。目前可以通过torch.export.export()来实现。

本文档将涵盖的关键概念包括:

  • ExportedProgram: 包含导出IR程序的数据结构

  • 图:由一组节点组成。

  • 节点:表示在此节点上存储的操作、控制流和元数据。

  • 值由节点生成和消费。

  • 类型与值和节点相关联。

  • 值的大小和内存布局也已定义。

假设

本文档假设读者对 PyTorch 有足够的了解,特别是对 torch.fx 及其相关工具。因此,本文档将不再描述 torch.fx 文档和论文中已有的内容。

什么是导出IR

导出IR是基于图的PyTorch程序的中间表示IR。 导出IR是在torch.fx.Graph之上实现的。换句话说,所有导出IR图也是有效的FX图,并且如果使用标准的FX语义进行解释,导出IR可以被合理地解释。一个推论是,导出的图可以通过标准的FX代码生成转换为有效的Python程序。

本文档将主要关注Export IR在严格性方面与FX的不同之处,同时跳过与FX相似的部分。

导出的程序

顶级导出IR结构是一个torch.export.ExportedProgram 类。它将PyTorch模型的计算图(通常是一个 torch.nn.Module)与该模型所使用的参数或权重捆绑在一起。

一些值得注意的torch.export.ExportedProgram类的属性包括:

  • graph_module (torch.fx.GraphModule): 包含PyTorch模型的扁平化计算图的数据结构。可以通过ExportedProgram.graph直接访问该图。

  • graph_signature (torch.export.ExportGraphSignature): 图的签名,指定了在图中使用和修改的参数和缓冲区名称。参数和缓冲区不是作为图的属性存储,而是作为图的输入提升。graph_signature 用于跟踪这些参数和缓冲区的附加信息。

  • state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]): 包含参数和缓冲区的数据结构。

  • range_constraints (Dict[sympy.Symbol, RangeConstraint]): 对于导出具有数据依赖行为的程序,每个节点的元数据将包含符号形状(例如 s0i0)。此属性将符号形状映射到它们的上下限范围。

导出 IR 图是一个以 DAG(有向无环图)形式表示的 PyTorch 程序。该图中的每个节点代表一个特定的计算或操作,而该图的边由节点之间的引用组成。

我们可以查看具有此模式的图表:

class Graph:
  nodes: List[Node]

在实践中,Export IR 的图被实现为 torch.fx.Graph Python 类。

导出的IR图包含以下节点(节点将在下一节中详细描述):

  • 0 个或更多类型为 placeholder 的节点

  • 0 个或更多操作类型为 call_function 的节点

  • 恰好1个操作类型为output的节点

推论:最小的有效图将由一个节点组成。即节点永远不会为空。

定义: 图的占位符节点集合表示输入 GraphModule图的输出节点表示输出 GraphModule图的。

示例:

from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

mod = torch.export.export(MyModule())
print(mod.graph)

上述是图的文本表示,每行代表一个节点。

节点

一个节点表示特定的计算或操作,并在 Python 中使用 torch.fx.Node 类表示。节点之间的边通过节点类的 args 属性表示为对其他节点的直接引用。使用相同的 FX 机制,我们可以表示计算图通常需要的以下操作,例如操作符调用、占位符(即输入)、条件语句和循环。

节点具有以下模式:

class Node:
  name: str # 节点名称
  op_name: str  # 操作类型

  # 以下字段的解释取决于 op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX 文本格式

如上例所示,注意每行都有这种格式:

%:[...] = [target=](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式以紧凑的形式捕获了Node类中除meta之外的所有内容。

具体来说:

  • 是节点名称,如在 node.name 中所示。

  • node.op 字段,必须是以下之一: , , , 或

  • 是节点的目标,如 node.target。此字段的含义取决于 op_name

  • args1, … args 4… 是在 node.args 元组中列出的内容。如果列表中的值是一个 torch.fx.Node,那么它将特别用前导 % 表示。

例如,调用加法运算符的操作将显示为:

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x%y 是另外两个名为 x 和 y 的节点。值得注意的是,字符串 torch.op.aten.add.Tensor 表示实际存储在目标字段中的可调用对象,而不仅仅是它的字符串名称。

这种文本格式的最后一行是:

返回 [add]

这是一个节点,其 op_name = output,表示我们正在返回这个元素。

调用函数

一个 call_function 节点表示对一个操作符的调用。

定义

  • 功能性:如果一个可调用对象满足以下所有要求,我们称其为“功能性的”:

    • 非突变:该操作符不会改变其输入的值(对于张量,这包括元数据和数据)。

    • 无副作用:该操作符不会改变从外部可见的状态,例如更改模块参数的值。

  • 操作符: 是一个具有预定义模式的可调用函数。此类操作符的示例包括功能性的ATen操作符。

在FX中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

与原生FX call_function的区别

  1. 在FX图表中,call_function可以引用任何可调用的对象,在Export IR中,我们将其限制为仅限于ATen操作符、自定义操作符和控制流操作符的一个选择子集。

  2. 在导出IR中,常量参数将被嵌入到图中。

  3. 在FX图中,一个get_attr节点可以表示读取图模块中存储的任何属性。然而,在导出IR中,这被限制为仅读取子模块,因为所有参数/缓冲区都将作为输入传递给图模块。

元数据

Node.meta 是附加到每个 FX 节点的字典。然而,FX 规范并没有指定元数据可以或将会是什么。导出 IR 提供了更强的契约,特别是所有 call_function 节点将保证具有且仅具有以下元数据字段:

  • node.meta["stack_trace"] 是一个包含Python堆栈跟踪的字符串,引用原始Python源代码。一个堆栈跟踪的示例如下:

    文件 "my_module.py",  19,  forward
    返回 x + dummy_helper(y)
    文件 "helper_utility.py",  89,  dummy_helper
    返回 y + 1
    
  • node.meta["val"] 描述了运行操作的输出。它可以是 类型 , , 一个 List[Union[FakeTensor, SymInt]], 或 None

  • node.meta["nn_module_stack"] 描述了节点来源的 torch.nn.Module 的“堆栈跟踪”,如果它来自 torch.nn.Module 调用。例如,如果一个包含 addmm 操作的节点从 torch.nn.Linear 模块内部调用,而该模块又位于 torch.nn.Sequential 模块中,则 nn_module_stack 将看起来像这样:

    {'self_linear': ('self.linear', ), 'self_sequential': ('self.sequential', )}
    
  • node.meta["source_fn_stack"] 包含在分解之前调用此节点的 torch 函数或叶子 torch.nn.Module 类。 例如,一个包含来自 torch.nn.Linear 模块调用的 addmm 操作的节点将包含 torch.nn.Linear 在其 source_fn 中,而包含来自 torch.nn.functional.Linear 模块调用的 addmm 操作的节点将包含 torch.nn.functional.Linear 在其 source_fn 中。

占位符

占位符表示图的一个输入。它的语义与FX中的完全相同。 占位符节点必须是图的节点列表中的前N个节点。N可以为零。

在FX中的表示

%name = placeholder[target = name](args = ())

目标字段是一个字符串,表示输入的名称。

args,如果不是空的,应该是大小为1,表示此输入的默认值。

元数据

占位符节点也具有 meta[‘val’],类似于 call_function 节点。在这种情况下,val 字段表示图期望为此输入参数接收的输入形状/数据类型。

输出

输出调用表示函数中的返回语句;它因此终止当前图。只有一个输出节点,并且它将始终是图的最后一个节点。

在FX中的表示

output[](args = (%something, …))

这与 torch.fx 中的语义完全相同。args 表示要返回的节点。

元数据

输出节点具有与call_function节点相同的元数据。

获取属性

get_attr 节点表示从封装的 torch.fx.GraphModule 中读取子模块。与来自 torch.fx.symbolic_trace() 的普通 FX 图不同,在普通 FX 图中,get_attr 节点用于读取属性,例如从顶层 torch.fx.GraphModule 读取参数和缓冲区,参数和缓冲区作为输入传递给图模块,并存储在顶层 torch.export.ExportedProgram 中。

在FX中的表示

%name = get_attr[target = name](args = ())

示例

考虑以下模型:

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

图表:

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

该行,%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0],读取包含sin操作符的子模块true_graph_0

参考资料

SymInt

SymInt 是一个对象,可以是字面整数或表示整数的符号(在 Python 中由 sympy.Symbol 类表示)。当 SymInt 是符号时,它描述了一个在编译时对图来说是未知的整数类型变量,即其值仅在运行时已知。

FakeTensor

FakeTensor 是一个包含张量元数据的对象。它可以被视为具有以下元数据。

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # 这还不存在

FakeTensor 的 size 字段是一个整数或 SymInts 的列表。如果存在 SymInts,这意味着该张量具有动态形状。如果存在整数,则假定该张量将具有该确切的静态形状。TensorMeta 的秩永远不会是动态的。dtype 字段表示该节点的输出的数据类型。在 Edge IR 中没有隐式类型提升。FakeTensor 中没有步幅。

换句话说:

  • 如果节点中的操作符返回一个张量,那么 node.meta['val'] 是一个描述该张量的 FakeTensor。

  • 如果节点中的操作符返回一个包含 Tensors 的 n 元组,那么 node.meta['val'] 是一个包含 FakeTensors 的 n 元组,用于描述每个张量。

  • 如果节点中的操作符返回一个在编译时已知的整数/浮点数/标量,那么node.meta['val']为None。

  • 如果节点中的操作符返回一个在编译时未知的整数/浮点数/标量,那么node.meta['val']的类型是SymInt。

例如:

  • aten::add 返回一个张量;因此,它的规格将是一个具有该操作符返回的张量的数据类型和大小的FakeTensor。

  • aten::sym_size 返回一个整数;因此它的值将是一个 SymInt,因为它的值只能在运行时获得。

  • max_pool2d_with_indexes 返回一个包含两个张量的元组(Tensor, Tensor);因此,规范也将是一个包含两个FakeTensor对象的2元组,第一个TensorMeta描述返回值的第一个元素等。

Python 代码:

def add_one(x):
  return torch.ops.aten(x, 1)

图表:

graph():
  %ph_0 : [#用户=1] = placeholder[target=ph_0]
  %add_tensor : [#用户=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor:

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree-able 类型

我们定义一个类型“Pytree-able”,如果它是一个叶类型或包含其他Pytree-able类型的容器类型。

注意:

pytree 的概念与 JAX 文档中 此处 所述的概念相同:

以下类型被定义为叶类型

类型

定义

张量

torch.Tensor

标量

Python中的任何数值类型,包括整数类型、浮点类型和零维张量。

整数

Python 整数(在 C++ 中绑定为 int64_t)

浮点数

Python 浮点数(在 C++ 中绑定作为双精度)

布尔值

Python 布尔值

字符串

Python 字符串

标量类型

torch.dtype

布局

torch.layout

内存格式

torch.memory_format

设备

torch.device

以下类型被定义为容器类型

类型

定义

元组

Python 元组

列表

Python 列表

字典

带有标量键的Python字典

命名元组

Python namedtuple

数据类

必须通过 register_dataclass 进行注册

自定义类

使用 _register_pytree_node 定义的任何自定义类

优云智算