Shortcuts

动态形状

代码: symbolic_shapes.py

参见: 动态形状手册

动机

深度学习编译器通常只适用于静态形状,也就是说,它们生成的编译程序只能用于单一特定的输入形状配置,如果输入形状发生变化,则必须重新编译。这一假设在当今大多数常见的深度学习模型中表现良好,但在某些情况下,这种假设是不够的:

  • 一些维度,如批次大小或序列长度,可能会有所不同。例如,执行自适应批处理的推理服务将根据其在批处理窗口内收到的请求数量,以不同的批次大小执行推理请求。我们可能还希望仅将可变大小的序列填充到批次内的最大序列长度,该长度可能会因批次而异。

  • 一些模型表现出数据依赖的输出形状,也就是说,它们的输出和中间结果的大小可能取决于实际输入数据,这些数据在不同的运行中可能会有所不同。例如,检测模型可能会首先生成数量可变的潜在边界框,然后运行一个更昂贵的图像识别模型来确定主题是否在边界框内。边界框的数量是数据依赖的。

  • 在处理稀疏表示(如稀疏张量、锯齿张量和图神经网络)时,数据依赖形状的一个特别重要的案例出现了。在这些情况下,要处理的数据量取决于问题的稀疏结构,而这种结构通常会以数据依赖的方式变化。

在支持动态形状时,我们选择不支持动态秩程序,例如,输入张量在维度上发生变化的程序,因为这种模式在现实世界的深度学习程序中很少出现,并且避免了需要对符号形状列表进行归纳推理的需求。

简化的公共API

PyTorch 2.1 中的默认动态行为是:

  • PT2 默认假设一切都是静态的

  • 如果我们因为尺寸变化而重新编译,我们将尝试将该尺寸作为动态尺寸进行重新编译(已变化的尺寸可能在将来再次变化)。这种泛化可能会失败(例如,因为用户代码在相关尺寸上进行了条件分支,或者PT2中缺少动态形状支持)。如果您试图理解为什么PT2对某些代码进行了过度特化,请使用TORCH_LOGS=dynamic运行,并查找“eval”条目,这些条目会说明何时添加了保护以及原因。

  • 如果你事先知道某些内容将是动态的,你可以跳过第一次重新编译,使用 torch._dynamo.mark_dynamic(tensor, dim)。如果你事先知道这个维度可以取的 minmax 值,你可以指定 torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)

  • 如果你说 torch.compile(dynamic=False),我们将关闭自动动态形状并在每次重新编译时为每个不同的尺寸重新编译。 相反,如果你说 torch.compile(dynamic=True),我们将尝试使所有内容尽可能动态。这对于小型操作符非常有用;如果你尝试在大模型上使用它,它可能会(1)可能导致PT2崩溃,并且(2)由于没有充分的理由而运行缓慢。

守卫模型

在考虑如何为 TorchDynamo 和 TorchInductor 添加对动态形状的支持时,我们做出了一个重要的设计决策:为了重用针对 PyTorch API 编写的 Python/C++ 中的分解和其他现有代码,我们必须能够跟踪动态形状。与可能捕获条件语句两个分支的完全符号系统不同,我们总是选择一个分支,并在假设我们将来只会在这个分支上做出相同选择的情况下进行专门的跟踪。为此,我们为每个符号大小维护一个“提示”,指示其在编译时的具体值(由于 TorchDynamo 是一个即时编译器,它总是知道实际的输入大小。)当我们对张量执行条件判断时,我们只需查阅提示以确定要选择哪个分支。

这大大简化了我们生成的符号形状公式,但意味着我们有一个更复杂的系统来管理防护措施。例如,考虑以下程序:

def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)

我们将使用 TorchInductor 编译的最终 IR 要么是 torch.cat([x, y]).add(2) 要么是 torch.cat([x, y]).mul(2)(条件已被展平),但要确定我们在哪个分支中,我们需要知道中间变量 z 的大小。因为 TorchDynamo 必须提前知道编译的跟踪是否有效(我们不支持像某些 JIT 编译器那样的保释操作),我们必须能够将 z.size(0) 表示为输入的表达式,即 x.size(0) + y.size(0)。这是通过为 PyTorch 中的所有操作符编写元函数来实现的,这些元函数可以在不实际对节点进行计算的情况下,将大小信息传播到张量的输出。

整体架构

符号形状工作流程:

  1. 当我们开始在Dynamo中编译一个帧时,我们会分配一个ShapeEnv(附加到FakeTensorMode),它用于跟踪符号形状状态。

  2. 我们在入口处为张量分配符号大小(静态或动态是策略决策,有一些控制选项)。

  3. 我们通过操作符传播符号大小,同时维护(1)FX IR,以便我们可以忠实地导出符号计算,以及(2)表示大小变量的Sympy表达式,以便我们可以对它们进行推理。

  4. 当我们基于符号大小进行条件处理时,无论是在Dynamo跟踪中还是在Inductor优化中,我们都会根据条件添加保护措施。这些保护措施可以由Python和C++共同引发。

  5. 这些保护可以对符号变量进行进一步的简化。例如,如果你断言 s0 == 4,我们现在可以将所有出现的 s0 替换为 4

  6. 当我们完成跟踪和优化后,我们将所有这些保护措施与编译后的代码一起安装;只有当所有保护措施都评估为真时,编译后的代码才是可重用的。

重要文件:

  • C++ SymInt API: c10/core/SymInt.h, SymFloat.h, SymBool.h

  • Python SymInt API: torch/__init__.py (查找 SymInt/SymFloat/SymBool)

  • C++ 底层实现: c10/core/SymNodeImpl.h, torch/csrc/utils/python_symnode.h, torch/csrc/jit/python/init.cpp

  • Python 基础设施: torch/fx/experimental/symbolic_shapes.py

  • 其他重要文件:torch/_subclasses/fake_tensor.py, torch/_meta_registrations.py, 分解, PrimTorch 引用

简化的内部API

理解Python类层次结构:

  • SymInt/SymFloat/SymBool: 这些是用户可见的类,用于模拟它们的 int/float/bool 对应类型。如果你将两个 SymInt 相加,我们会给你一个新的 SymInt,它符号化地记录了整数加法的发生。

  • SymNode: 这是内部结构(可通过例如 symint.node 访问),用于保存实际的符号跟踪信息。SymNode 是类型擦除的;这使得表示混合类型操作更加方便。请注意,从技术上讲,您不必从 SymInt 调用 Python 的 SymNode;例如,XLA 的 C++ SymNodeImpl 将代替 SymNode。

  • ShapeEnv: 每个编译上下文的状态,用于跟踪到目前为止我们积累的所有自由符号和保护。每个 SymNode 记录其 ShapeEnv(但反之则不然;只有当 SymNode 参与保护时才会使用它们)。

C++ 也非常相似:

  • c10::SymInt/SymFloat/SymBool: 用户可见的类,用于模拟int/float/bool。

  • c10::SymNode/SymNodeImpl: 类似于 SymNode

  • 在C++中没有ShapeEnv;为了便于调试,整个符号推理装置都在Python中。

当你编写可以使用 make_fx 进行追踪的代码时,它必须能够处理 SymInt/SymFloat/SymBool 在其中流动。动态形状手册 提供了一些关于如何做到这一点的指导。

DimDynamic 策略

符号推理:

  • 值范围

  • Sympy 使用笔记

  • 约束条件

  • DimDynamic/约束

未支持的 SymInts

为了解析控制流,我们检查符号整数的提示(即实际值),以确定要走哪个分支。然而,在某些情况下,我们可能没有提示:当一个大小变量从一个数据依赖操作(如.nonzero().item())中产生时,就会出现所谓的无支持符号整数。在这些符号整数上执行控制流是非法的,因此我们必须在这些操作上进行图中断。

天真地实现,这太严格了:如果你尝试对无支持的符号整数进行任何操作,大多数 PyTorch 程序将立即失败。以下是使其实际工作的最重要的改进:

  • 在创建张量时,PyTorch 会预先计算关于张量的许多数据;例如,如果您使用 empty_strided 创建张量,我们会立即对步幅进行排序,并确定张量是否为非重叠且密集的。排序会产生许多保护措施。然而,更常见的是直接使用像 empty 这样的高级 API 来生成张量,这样可以保证生成一个非重叠且密集的张量。我们修改了 PyTorch,以避免不必要地重新计算这些属性。

  • 即使需要进行复杂的计算,有时某个属性可能从未被实际查询过。将这些预计算属性设为惰性属性,可以让我们避免在没有实际需要时对未支持的符号整数进行保护。

  • 整数张量中的数据通常不被认为是非负的。然而,我们提供了一个API constrain_range,用户可以通过该API指定一个大小在已知的上限和下限之间。

在PT2的未来版本(PT2.1之后),我们将扩展我们的推理系统,以基于使用情况推断出一个无支持的符号整数是类似大小的。例如,如果你将一个.item()调用的结果传递给一个工厂函数,如torch.empty,我们将自动推断该结果是一个大小(因为如果不是,它将会失败。)这个假设将在运行时得到验证,如果未满足,则会引发错误。

优云智算