Shortcuts

守卫概述

从用户体验的角度来看,TorchDynamo 非常易于使用。用户调用 torchdynamo.optimize 作为注解:

@torchdynamo.optimize(my_compiler)
def fn_foo(bar):

一个完整的示例看起来像这样:

from typing import List
import torch
from torch import _dynamo as torchdynamo

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() 被调用,传入的 FX 图为:")
    gm.graph.print_tabular()
    return gm.forward  # 返回一个 Python 可调用对象

@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

这使得 TorchDynamo 能够捕获解释的 Python 帧,获取所有相关信息,并在任何可能的地方加速。加速来自于几个方面,并且可能相当依赖于提供的后端(如上例中的 my_compiler),但本节中重要的是 缓存。缓存本身并不是直接的加速,而是一个关键的启用,可以防止重新编译。我们用 dynamo 挖了一个洞,而缓存使我们能够出来。它使我们能够在保持性能中立的同时启用后端——这是我们加速的真正来源。

即使提供了透传的空操作后端:

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    return gm.forward

我们可以看到TorchDynamo即使在常规Python上也能加速Python执行,而不仅仅是在PyTorch上。

缓存和守卫概述

TorchDynamo 通过缓存由 TorchDynamo 转换的用户字节码来运行。当 TorchDynamo 接收到一个用于评估的帧时,它会检查帧中引用的对象是否以某些方式发生了变化,如果没有,TorchDynamo 会读取之前转换的用户字节码来进行评估。在本节中,我们将重点讨论如何识别帧中引用的对象是否发生了变化。这是 TorchDynamo 中一个关键的功能,因为它驱动了整个失效生命周期。这个功能被称为 guards

在高层次上,流程可以概括如下:

  1. TorchDynamo 接收一个 Python 帧。

  2. 它将帧(1)通过指令翻译进行转换。

  3. 对于在(2)中捕获的对象,TorchDynamo 创建了跟踪对象,这些对象是:

    • 跟踪在一个输出图上,这是一个torch.fx.Tracer的内部特化

    • 守卫

  4. TorchDynamo 处理在 (3) 中创建的守卫对象,将其转换为生成的 Python 函数,check_fn,与一段代码相关联。

  5. 每当再次遇到此代码时,会评估check_fn - 如果check_fn通过并评估为True,TorchDynamo会将缓存中的代码和此处遇到的代码识别为相同,并且可以安全使用。如果失败并评估为False,TorchDynamo会将缓存中的代码识别为无效,并可以通过重新编译或图中断来丢弃,以便为新条目腾出空间。

Python 帧评估与 PEP 523

TorchDynamo 的功能基于 PEP 523

TorchDynamo 通过使用 _PyInterpreterState_SetEvalFrameFunc 在 Python 上安装了一个帧评估函数。TorchDynamo 有一个钩子,Python 可以在评估期间将控制权交还给我们。

我们安装的函数是 convert_frameconvert_frame_assertnopython=True 的情况下,但暂时忽略这个细微差别,让我们来看看 convert_frame_assert, 因为 convert_frame 代理到它。

我们可以在 torch/_dynamo/convert_frame.py 中找到该函数,其签名如下:

def  convert_frame_assert(compiler_fn: Callable, one_graph=True):

此函数包装了Python调用TorchDynamo的入口点,并带有一个帧:

def  _convert_frame_assert(frame: types.FrameType, cache_size: int):

这个函数的功能如下:

  1. 检查是否之前见过这个代码(参见:f_code 这里),如果是则提前退出。

  2. 检查代码是否为不支持的情况。

  3. 检查cache_size(上述第二个参数)是否超过了配置中定义的限制cache_size_limit。如果超过了,函数将丢弃该帧并记录警告。这有助于避免帧的持续重新编译,因为通常这意味着该帧以意外的方式处于热状态,缓存它会产生不必要的开销,因为它很可能在下一次遇到时被逐出。

  4. 传递帧以及一个通过字节码转换创建InstructionTranslator的函数,通过transform_code_object。这里在幕后发生了一些关键的事情:

    1. 新代码通过 transform_code_object 生成。

    2. 通过InstructionTranslator生成了一个名为output的FX跟踪器。这可能会有些令人困惑,因为InstructionTranslator并不是一个fx跟踪器,但它存储在一个名为tracer的变量中,并且它的输出确实是一个fx跟踪器。

    3. 该函数生成防护措施并将其存储在output上方。

    4. 该函数生成output_instructions并将其存储在 output上。

    5. 该函数将新产生的转换后的代码映射到最初从帧中读取的初始代码。这个映射值得记住,我们将在下面讨论防护失败时多次提到它。

  5. 使用来自4.1的转换代码和来自4.3的防护措施, 该函数生成一个受防护的代码

现在我们已经学习了框架评估,让我们回顾一下 InstructionTranslator,看看它是如何将我们传递给它的框架转换为TorchDynamo内部类型的。

指令翻译器

InstructionTranslator 做了很多事情!我们不会详细介绍它所做的所有事情,但最重要的是,对于本文档而言,它生成一个 symbolic_locals 的映射,该映射维护从帧的 f_locals 到 TorchDynamo 内部变量对象的映射(稍后会详细介绍这些内容)。symbolic_locals 是通过遍历帧的局部变量来填充的:

self.symbolic_locals = collections.OrderedDict(
    (k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
    for k in vars
    if k in f_locals
)

这里的重要组件是对 VariableBuilder 的调用。VariableBuilder 的调用实现代理到一个名为 _wrap 的函数,该函数依次构造 VariableTracker 的实例并在其上调用 make_guards。稍后会详细介绍。

这种映射至关重要,因为每个变量都有关联的守卫,这些守卫随后被传递给self.output,即OutputGraph的实例,一个fx追踪器,如上文4.2节所述。如果你还记得,这个OutputGraph,存储在一个名为output的变量中,是我们存储守卫的地方,这些守卫在被传递之前被存储,以便成为GuardedCode

InstructionTranslator 是如何做到这一点的?其核心是一个循环,该循环驱动一个名为 step 的函数。

step 就是这样 - 一个单一的处理步骤,接受一个指令并对其进行某种操作

注意

这些是由 TorchDynamo 的 transform_code_object 处理的实际指令,非常酷。

注意

本节有意跳过 dis.get_instructions的细节。

对于上述示例,以下是一个可能的几个指令的片段:

指令(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None)
指令(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None)
指令(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None)

这是该函数的核心功能。请查看 opname, 然后查看来自 step 内部的这个小片段;

if not hasattr(self, inst.opname):
    unimplemented(f"缺少: {inst.opname}")
getattr(self, inst.opname)(inst)

正如我们所见,该函数检查当前类,即InstructionTranslator,是否有一个与操作符名称匹配的属性集(例如,LOAD_CONST)。如果有,该函数会调用它,并传递整个指令对象。如果没有,该函数会将帧丢弃为未实现。

对于 LOAD_CONST 示例,我们可以看到我们确实支持它,定义相对简单:

def LOAD_CONST(self, inst):
    self.push(ConstantVariable(value=inst.argval))

我们可以看到,这个函数创建了一个新的类实例 ConstantVariable,在我们的例子中,值为-1,然后将其压入栈中。

有几十种这样的方法 - 请参阅 symbolic_convert.py 以获取所有这些方法。通常,我们会尽可能多地实现与 Python 字节码指令匹配的方法。

step 之后的逻辑和调用 VariableBuilder 的逻辑中,我们现在有很多 VariableTracker,当然,我们也已经讨论了很多关于创建 guards 的内容。让我们深入了解什么是 Variables,并更接近理解 guards。

变量

一个 ConstantVariableVariableTracker 的一个实例。 VariableTracker 表示一个被跟踪的 Python 局部变量或栈值。

当涉及到在TorchDynamo中表示一个对象时, VariableTracker 确实如其名 - 它跟踪给定的变量。 它是一个非常灵活的类,但有一些要点需要注意:

  • 它通过以下方式管理底层对象的guard关系:

    • make_guard

    • replace_guards

    • add_guard(s)

    • propagate - propagate(*vars: List[List["VariableTracker"]]) - 也许是最重要的,因为它结合了所有提供的VariableTracker实例中的守卫。它访问这些守卫并将它们合并到自身中。

  • 它作为底层对象的代理,为TorchDynamo实现方法,以获取有关跟踪对象的信息:

    • 调用方法

    • 调用函数

    • python_type

    • as_proxy

    • is/as_python_proxy

  • 它存储了类型为 Source 的变量 source,来自 torchdynamo/source.py。这种源类型是一个相对独立的类,帮助我们组织和记录原始源的来源,并提供一些便利方法, 比如获取名称,以及对我们来说非常重要的生成守卫。

而这个类(VariableTracker)是围绕子类化构建的, 介于一个完整的抽象基类和完全实现的类之间 - 它让许多方法抛出 NotImplementedError - 依赖于 子类。请参阅 torchdynamo/variables/ 以获取所有子类以实现 合同和自定义行为。

根据我们现在所知,我们可以看到一个来自 dis 指令的示例,BUILD_TUPLE

BUILD_TUPLE(count) 创建一个元组,从栈中消耗count个元素,并将生成的元组压入栈中。

在我们的例子中,由于我们创建Instruction对象的方式,我们的签名会有一点不同,但大致思路是一样的。我们不是传入count,而是传入一个带有一些额外簿记的对象,当然,我们还需要处理将普通的旧python对象转换为TorchDynamo的概念:

def BUILD_TUPLE(self, inst):
    items = self.popn(inst.argval)
    options = VariableTracker.propagate(items)
    self.push(TupleVariable(items, **options))

这段代码的作用如下:

  1. 该函数读取 argval,在这种情况下,类似于等效指令的 pydoc 中的 counts

  2. 函数 popn 这些项,在这种情况下,签名是 def  popn(self, n: int) -> List[TensorVariable]: 这暗示了一个 潜在的契约 - 我们正在返回 TensorVariables。如果我们 仔细查看 symbolic_convert.pyInstructionTranslatorBase/InstructionTranslator,我们看到 唯一被推入和弹出我们的堆栈的是 VariableTracker

  1. 函数调用 VariableTracker.propagate。这会从栈中弹出的每个项目中获取守卫,并递归遍历并将所有守卫组合到 options 中:py  return {      "guards": guards,  }

  2. 然后,该函数创建一个新的 VariableTracker 实例, TupleVariableitemsoptions 组成。这使得我们能够从构成新 TupleVariableitems 中安装所有适当的保护措施。

注意

第一批守卫是从哪里来的?传播是一种很好的技术,但我们需要在传播之前创建一些东西。VariableBuilder 在创建 VariableTracker 实例时调用 make_guards,这些实例来自 f_locals。这反过来又调用 source,让它创建守卫。

经过这一切之后,字节码翻译已经完成,我们离生成GuardedCode又近了一步。我们现在了解了局部变量如何成为VariableTracker,指令是如何处理的,以及在创建时在哪里调用了保护。在我们能够了解代码和保护如何结合成一个GuardedCode对象之前,我们需要深入研究一下上面的make_guardsource.make_guard调用。然后我们可以理解,当我们与VariableTracker实例一起创建保护时,发生了什么。

制作守卫

Guards 只是 Python 对象,属于 Guard 类。让我们更详细地看看它们。

查看数据类的定义(因此,构造函数签名),我们可以看到它有一个名称、一个源和一个创建函数。

@dataclasses.dataclass
class Guard:
    name: str
    source: GuardSource
    create_fn: Callable

名称应为变量的名称。

这里的源是一个枚举,指示防护属于哪种类型的源。

注意

不要与Source以及source.py中的其他类型混淆,这些类型存储在VariableTracker中。

create_fn 提供了从简单的数据类过渡到实际生成有效Python代码的主要功能,以便在调用之间了解事物是否发生了变化,以及我们是否可以安全地从代码缓存中读取。

获取 guard 实例的最常见代码路径是通过 VariableTracker 上的 make_guardsmake_guards -> source.make_guard -> return Guard(self.name(), self.guard_source(), fn)

或者,在一个具体的例子中:

...
elif istype(value, range):
    guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
    return RangeVariable(value=value, guards=guards)

由于在构造此VariableTracker时已设置了source,因此这里只需要提供fnGuardBuilder.EQUALS_MATCHcreate_fn字段。

这个 create_fn 必须是 GuardBuilder 上的一个方法。这样做的理由在我们下一步中会变得显而易见。一旦我们为某个帧创建了所有的守卫,我们就会转到 CheckFunctionManagercompile_check_fn

convert_frame函数能够生成一个GuardedCode之前, 它需要运行CheckFunctionManager,并带上所有的保护措施,以生成一个check_fn, 然后这个check_fn将与代码一起传递到GuardedCode中。 这是我们存储在我们的缓存条目中的同一个check_fn,也是我们运行以确定是否检索存储的代码的同一个check_fn。 作为参考,这里是该代码:

static CacheEntry *create_cache_entry(CacheEntry *next,
                                      PyObject *guarded_code) {
  CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry));
  DEBUG_NULL_CHECK(e);
  e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
  NULL_CHECK(e->check_fn);
  e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code");
  NULL_CHECK(e->code);
  e->next = next;
  return e;
}

我们现在知道了一个check_fn函数是如何使用的,以及谁创建了它,以及它由什么组成,但我们还不知道的是它是如何实现的。一个Guard对象列表是如何变成我们稍后可以运行的函数的?

首先,我们迭代这些守卫:

for guard in sorted(guards or [], key=Guard.sort_key):
    if not config.guard_nn_modules and guard.is_nn_module():
        continue
    guard.create(local_builder, global_builder)

调用 guard.create 会运行我们在上面的 Guard 类中设置的 create_fn(不要与我们在生产中正在处理的 check_fn 混淆,名称相似,可能会有些混淆)。在我们的示例中,我们的 create_fnGuardBuilder.EQUALS_MATCH。因此我们现在正在调用它,传入 self,即 guard 本身。

签名是:def EQUALS_MATCH(self, guard: Guard):

在函数内部,我们可以使用 name 守卫来获取原始对象,查询其数据和类型信息, 这反过来又让我们到达了最重要的部分:追加代码。

最简单的情况下,EQUALS_MATCH 只添加一行代码: self.code.append(f"{ref} == {val!r}")。其中 ref 是变量的名称,而 val 是值。它可能会生成如下代码:

y == 2

这是一个基本示例。但如果我们添加一些其他类型的 GuardBuilder 函数,然后将它们全部用 and 连接起来(就像我们做的那样),我们可能会得到类似这样的结果:

___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x)

这段代码执行的操作如下:

  1. 检查 .valid

  2. 类型ID检查

  3. 值检查

  4. 张量检查

这成为我们代码的核心部分 check_fn,它将在我们下次遇到此代码时被评估。然后它将检查:

  1. 这段代码仍然有效吗?

  2. 如果 (1),y 仍然具有 94367738391392 类型吗?

  3. 如果 (2),y 仍然是 2 吗?

  4. 如果 (3),让我们检查张量 x 是否以某种特定方式发生了变化。

如果这些条件仍然成立,那么我们可以使用与此check_fn一起缓存的代码。

注意

关于这一过程如何以及在何处发生的更深入探讨,您可以阅读 static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) {_eval_frame.c 部分。

如果没有,那么,我们可以继续重新编译代码,并将其与这段代码一起存储在缓存中,以及一个新的check_fn,同样在后续的帧中进行检查。

GuardBuilder上有许多其他类似的函数,它们会被合并成有时非常庞大的字符串,然后作为Python代码进行评估并存储到check_fn中。上面的例子展示了一个简单的案例。要更好地理解这一功能,请阅读GuardBuilder上的其他函数,或者更好的是,在compile_check_fn中转储code变量,以查看生成了什么内容,特别是在较大的实际模型中。

概述

在本节中,我们已经回顾了:

  • 关于弱引用(以及可能即将成为NN模块失效)的.valid和失效的角色。

  • C++端守卫函数(___check_type_id___check_tensors等)的运作方式。

  • 当守卫失败时会发生什么。

  • 如果我们生成无效的保护代码会发生什么。

我们介绍了在TorchDynamo上下文中用户提供的代码如何被内部跟踪和记录,组织成VariableTrackerSource,随后是Guard,以及这些Guard如何指导缓存条目的选择和失效,当处理Python代码时。

优云智算