守卫概述¶
从用户体验的角度来看,TorchDynamo 非常易于使用。用户调用
torchdynamo.optimize 作为注解:
@torchdynamo.optimize(my_compiler)
def fn_foo(bar):
一个完整的示例看起来像这样:
from typing import List
import torch
from torch import _dynamo as torchdynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() 被调用,传入的 FX 图为:")
gm.graph.print_tabular()
return gm.forward # 返回一个 Python 可调用对象
@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
这使得 TorchDynamo 能够捕获解释的 Python 帧,获取所有相关信息,并在任何可能的地方加速。加速来自于几个方面,并且可能相当依赖于提供的后端(如上例中的 my_compiler),但本节中重要的是 缓存。缓存本身并不是直接的加速,而是一个关键的启用,可以防止重新编译。我们用 dynamo 挖了一个洞,而缓存使我们能够出来。它使我们能够在保持性能中立的同时启用后端——这是我们加速的真正来源。
即使提供了透传的空操作后端:
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
return gm.forward
我们可以看到TorchDynamo即使在常规Python上也能加速Python执行,而不仅仅是在PyTorch上。
缓存和守卫概述¶
TorchDynamo 通过缓存由 TorchDynamo 转换的用户字节码来运行。当 TorchDynamo 接收到一个用于评估的帧时,它会检查帧中引用的对象是否以某些方式发生了变化,如果没有,TorchDynamo 会读取之前转换的用户字节码来进行评估。在本节中,我们将重点讨论如何识别帧中引用的对象是否发生了变化。这是 TorchDynamo 中一个关键的功能,因为它驱动了整个失效生命周期。这个功能被称为 guards。
在高层次上,流程可以概括如下:
TorchDynamo 接收一个 Python 帧。
它将帧(1)通过指令翻译进行转换。
对于在(2)中捕获的对象,TorchDynamo 创建了跟踪对象,这些对象是:
跟踪在一个输出图上,这是一个torch.fx.Tracer的内部特化
守卫
TorchDynamo 处理在 (3) 中创建的守卫对象,将其转换为生成的 Python 函数,check_fn,与一段代码相关联。
每当再次遇到此代码时,会评估check_fn - 如果check_fn通过并评估为True,TorchDynamo会将缓存中的代码和此处遇到的代码识别为相同,并且可以安全使用。如果失败并评估为False,TorchDynamo会将缓存中的代码识别为无效,并可以通过重新编译或图中断来丢弃,以便为新条目腾出空间。
Python 帧评估与 PEP 523¶
TorchDynamo 的功能基于 PEP 523。
TorchDynamo 通过使用 _PyInterpreterState_SetEvalFrameFunc 在 Python 上安装了一个帧评估函数。TorchDynamo 有一个钩子,Python 可以在评估期间将控制权交还给我们。
我们安装的函数是 convert_frame 或
convert_frame_assert 在 nopython=True 的情况下,但暂时忽略这个细微差别,让我们来看看 convert_frame_assert,
因为 convert_frame 代理到它。
我们可以在 torch/_dynamo/convert_frame.py 中找到该函数,其签名如下:
def convert_frame_assert(compiler_fn: Callable, one_graph=True):
此函数包装了Python调用TorchDynamo的入口点,并带有一个帧:
def _convert_frame_assert(frame: types.FrameType, cache_size: int):
这个函数的功能如下:
检查是否之前见过这个
代码(参见:f_code 这里),如果是则提前退出。检查代码是否为不支持的情况。
检查
cache_size(上述第二个参数)是否超过了配置中定义的限制cache_size_limit。如果超过了,函数将丢弃该帧并记录警告。这有助于避免帧的持续重新编译,因为通常这意味着该帧以意外的方式处于热状态,缓存它会产生不必要的开销,因为它很可能在下一次遇到时被逐出。传递帧以及一个通过字节码转换创建
InstructionTranslator的函数,通过transform_code_object。这里在幕后发生了一些关键的事情:新代码通过
transform_code_object生成。通过
InstructionTranslator生成了一个名为output的FX跟踪器。这可能会有些令人困惑,因为InstructionTranslator并不是一个fx跟踪器,但它存储在一个名为tracer的变量中,并且它的输出确实是一个fx跟踪器。该函数生成防护措施并将其存储在
output上方。该函数生成
output_instructions并将其存储在output上。该函数将新产生的转换后的代码映射到最初从帧中读取的初始代码。这个映射值得记住,我们将在下面讨论防护失败时多次提到它。
使用来自4.1的转换代码和来自4.3的防护措施, 该函数生成一个受防护的代码。
现在我们已经学习了框架评估,让我们回顾一下
InstructionTranslator,看看它是如何将我们传递给它的框架转换为TorchDynamo内部类型的。
指令翻译器¶
InstructionTranslator 做了很多事情!我们不会详细介绍它所做的所有事情,但最重要的是,对于本文档而言,它生成一个 symbolic_locals 的映射,该映射维护从帧的 f_locals 到 TorchDynamo 内部变量对象的映射(稍后会详细介绍这些内容)。symbolic_locals 是通过遍历帧的局部变量来填充的:
self.symbolic_locals = collections.OrderedDict(
(k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
for k in vars
if k in f_locals
)
这里的重要组件是对 VariableBuilder 的调用。VariableBuilder 的调用实现代理到一个名为 _wrap 的函数,该函数依次构造 VariableTracker 的实例并在其上调用 make_guards。稍后会详细介绍。
这种映射至关重要,因为每个变量都有关联的守卫,这些守卫随后被传递给self.output,即OutputGraph的实例,一个fx追踪器,如上文4.2节所述。如果你还记得,这个OutputGraph,存储在一个名为output的变量中,是我们存储守卫的地方,这些守卫在被传递之前被存储,以便成为GuardedCode。
InstructionTranslator 是如何做到这一点的?其核心是一个循环,该循环驱动一个名为 step 的函数。
step 就是这样 - 一个单一的处理步骤,接受一个指令并对其进行某种操作。
注意
这些是由 TorchDynamo 的
transform_code_object 处理的实际指令,非常酷。
注意
本节有意跳过 dis.get_instructions的细节。
对于上述示例,以下是一个可能的几个指令的片段:
指令(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None)
指令(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None)
指令(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None)
这是该函数的核心功能。请查看 opname,
然后查看来自 step 内部的这个小片段;
if not hasattr(self, inst.opname):
unimplemented(f"缺少: {inst.opname}")
getattr(self, inst.opname)(inst)
正如我们所见,该函数检查当前类,即InstructionTranslator,是否有一个与操作符名称匹配的属性集(例如,LOAD_CONST)。如果有,该函数会调用它,并传递整个指令对象。如果没有,该函数会将帧丢弃为未实现。
对于 LOAD_CONST 示例,我们可以看到我们确实支持它,定义相对简单:
def LOAD_CONST(self, inst):
self.push(ConstantVariable(value=inst.argval))
我们可以看到,这个函数创建了一个新的类实例
ConstantVariable,在我们的例子中,值为-1,然后将其压入栈中。
有几十种这样的方法 - 请参阅 symbolic_convert.py 以获取所有这些方法。通常,我们会尽可能多地实现与 Python 字节码指令匹配的方法。
在 step 之后的逻辑和调用 VariableBuilder 的逻辑中,我们现在有很多 VariableTracker,当然,我们也已经讨论了很多关于创建 guards 的内容。让我们深入了解什么是 Variables,并更接近理解 guards。
变量¶
一个 ConstantVariable 是 VariableTracker 的一个实例。
VariableTracker 表示一个被跟踪的 Python 局部变量或栈值。
当涉及到在TorchDynamo中表示一个对象时,
VariableTracker 确实如其名 - 它跟踪给定的变量。
它是一个非常灵活的类,但有一些要点需要注意:
它通过以下方式管理底层对象的
guard关系:make_guardreplace_guardsadd_guard(s)propagate-propagate(*vars: List[List["VariableTracker"]])- 也许是最重要的,因为它结合了所有提供的VariableTracker实例中的守卫。它访问这些守卫并将它们合并到自身中。
它作为底层对象的代理,为TorchDynamo实现方法,以获取有关跟踪对象的信息:
调用方法调用函数python_typeas_proxyis/as_python_proxy
它存储了类型为
Source的变量source,来自torchdynamo/source.py。这种源类型是一个相对独立的类,帮助我们组织和记录原始源的来源,并提供一些便利方法, 比如获取名称,以及对我们来说非常重要的生成守卫。
而这个类(VariableTracker)是围绕子类化构建的,
介于一个完整的抽象基类和完全实现的类之间
- 它让许多方法抛出 NotImplementedError - 依赖于
子类。请参阅 torchdynamo/variables/ 以获取所有子类以实现
合同和自定义行为。
根据我们现在所知,我们可以看到一个来自 dis 指令的示例,BUILD_TUPLE:
BUILD_TUPLE(count)创建一个元组,从栈中消耗count个元素,并将生成的元组压入栈中。
在我们的例子中,由于我们创建Instruction对象的方式,我们的签名会有一点不同,但大致思路是一样的。我们不是传入count,而是传入一个带有一些额外簿记的对象,当然,我们还需要处理将普通的旧python对象转换为TorchDynamo的概念:
def BUILD_TUPLE(self, inst):
items = self.popn(inst.argval)
options = VariableTracker.propagate(items)
self.push(TupleVariable(items, **options))
这段代码的作用如下:
该函数读取
argval,在这种情况下,类似于等效指令的 pydoc 中的counts。函数
popn这些项,在这种情况下,签名是def popn(self, n: int) -> List[TensorVariable]:这暗示了一个 潜在的契约 - 我们正在返回TensorVariables。如果我们 仔细查看symbolic_convert.py和InstructionTranslatorBase/InstructionTranslator,我们看到 唯一被推入和弹出我们的堆栈的是VariableTracker。
函数调用
VariableTracker.propagate。这会从栈中弹出的每个项目中获取守卫,并递归遍历并将所有守卫组合到options中:py return { "guards": guards, }然后,该函数创建一个新的
VariableTracker实例,TupleVariable由items和options组成。这使得我们能够从构成新TupleVariable的items中安装所有适当的保护措施。
注意
第一批守卫是从哪里来的?传播是一种很好的技术,但我们需要在传播之前创建一些东西。VariableBuilder 在创建 VariableTracker 实例时调用 make_guards,这些实例来自 f_locals。这反过来又调用 source,让它创建守卫。
经过这一切之后,字节码翻译已经完成,我们离生成GuardedCode又近了一步。我们现在了解了局部变量如何成为VariableTracker,指令是如何处理的,以及在创建时在哪里调用了保护。在我们能够了解代码和保护如何结合成一个GuardedCode对象之前,我们需要深入研究一下上面的make_guard和source.make_guard调用。然后我们可以理解,当我们与VariableTracker实例一起创建保护时,发生了什么。
制作守卫¶
Guards 只是 Python 对象,属于 Guard 类。让我们更详细地看看它们。
查看数据类的定义(因此,构造函数签名),我们可以看到它有一个名称、一个源和一个创建函数。
@dataclasses.dataclass
class Guard:
name: str
source: GuardSource
create_fn: Callable
名称应为变量的名称。
这里的源是一个枚举,指示防护属于哪种类型的源。
注意
不要与Source以及source.py中的其他类型混淆,这些类型存储在VariableTracker中。
create_fn 提供了从简单的数据类过渡到实际生成有效Python代码的主要功能,以便在调用之间了解事物是否发生了变化,以及我们是否可以安全地从代码缓存中读取。
获取 guard 实例的最常见代码路径是通过 VariableTracker 上的 make_guards。
make_guards -> source.make_guard -> return Guard(self.name(), self.guard_source(), fn)
或者,在一个具体的例子中:
...
elif istype(value, range):
guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
return RangeVariable(value=value, guards=guards)
由于在构造此VariableTracker时已设置了source,因此这里只需要提供fn,GuardBuilder.EQUALS_MATCH到create_fn字段。
这个 create_fn 必须是 GuardBuilder 上的一个方法。这样做的理由在我们下一步中会变得显而易见。一旦我们为某个帧创建了所有的守卫,我们就会转到 CheckFunctionManager 和 compile_check_fn。
在convert_frame函数能够生成一个GuardedCode之前,
它需要运行CheckFunctionManager,并带上所有的保护措施,以生成一个check_fn,
然后这个check_fn将与代码一起传递到GuardedCode中。
这是我们存储在我们的缓存条目中的同一个check_fn,也是我们运行以确定是否检索存储的代码的同一个check_fn。
作为参考,这里是该代码:
static CacheEntry *create_cache_entry(CacheEntry *next,
PyObject *guarded_code) {
CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry));
DEBUG_NULL_CHECK(e);
e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
NULL_CHECK(e->check_fn);
e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code");
NULL_CHECK(e->code);
e->next = next;
return e;
}
我们现在知道了一个check_fn函数是如何使用的,以及谁创建了它,以及它由什么组成,但我们还不知道的是它是如何实现的。一个Guard对象列表是如何变成我们稍后可以运行的函数的?
首先,我们迭代这些守卫:
for guard in sorted(guards or [], key=Guard.sort_key):
if not config.guard_nn_modules and guard.is_nn_module():
continue
guard.create(local_builder, global_builder)
调用 guard.create 会运行我们在上面的 Guard 类中设置的 create_fn(不要与我们在生产中正在处理的 check_fn 混淆,名称相似,可能会有些混淆)。在我们的示例中,我们的 create_fn 是 GuardBuilder.EQUALS_MATCH。因此我们现在正在调用它,传入 self,即 guard 本身。
签名是:def EQUALS_MATCH(self, guard: Guard):
在函数内部,我们可以使用 name 守卫来获取原始对象,查询其数据和类型信息,
这反过来又让我们到达了最重要的部分:追加代码。
最简单的情况下,EQUALS_MATCH 只添加一行代码:
self.code.append(f"{ref} == {val!r}")。其中 ref 是变量的名称,而 val 是值。它可能会生成如下代码:
y == 2
这是一个基本示例。但如果我们添加一些其他类型的 GuardBuilder
函数,然后将它们全部用
and 连接起来(就像我们做的那样),我们可能会得到类似这样的结果:
___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x)
这段代码执行的操作如下:
检查
.valid类型ID检查
值检查
张量检查
这成为我们代码的核心部分 check_fn,它将在我们下次遇到此代码时被评估。然后它将检查:
这段代码仍然有效吗?
如果 (1),
y仍然具有94367738391392类型吗?如果 (2),
y仍然是 2 吗?如果 (3),让我们检查张量
x是否以某种特定方式发生了变化。
如果这些条件仍然成立,那么我们可以使用与此check_fn一起缓存的代码。
注意
关于这一过程如何以及在何处发生的更深入探讨,您可以阅读 static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) { 的
_eval_frame.c 部分。
如果没有,那么,我们可以继续重新编译代码,并将其与这段代码一起存储在缓存中,以及一个新的check_fn,同样在后续的帧中进行检查。
在GuardBuilder上有许多其他类似的函数,它们会被合并成有时非常庞大的字符串,然后作为Python代码进行评估并存储到check_fn中。上面的例子展示了一个简单的案例。要更好地理解这一功能,请阅读GuardBuilder上的其他函数,或者更好的是,在compile_check_fn中转储code变量,以查看生成了什么内容,特别是在较大的实际模型中。
概述¶
在本节中,我们已经回顾了:
关于弱引用(以及可能即将成为NN模块失效)的
.valid和失效的角色。C++端守卫函数(
___check_type_id、___check_tensors等)的运作方式。当守卫失败时会发生什么。
如果我们生成无效的保护代码会发生什么。
我们介绍了在TorchDynamo上下文中用户提供的代码如何被内部跟踪和记录,组织成VariableTracker、Source,随后是Guard,以及这些Guard如何指导缓存条目的选择和失效,当处理Python代码时。