假张量¶
代码: fake_tensor.py
动机¶
在进行Dynamo符号评估和编译器传递时,我们通常希望能够运行张量操作以了解输出的大小/数据类型/设备,而不实际运行这些操作(或破坏现有的张量),因为这会更慢(如果你进行了大量计算)并且会占用大量内存(如果编译器在编译程序时需要使用GPU内存,这是不好的)。一个假张量在所有方面都像一个真正的张量,只是它实际上没有任何数据。例如,当我们进行Dynamo跟踪时,我们需要跟踪用户张量代码并回答关于中间结果的问题(例如,如果用户对中间张量进行条件判断)。没有假张量,我们将无法为这些查询提供准确的信息。
同样地,假设你想为一个张量存储元数据,例如,在一个FX IR节点上(meta[‘val’])。你可以直接在节点上存储一个假张量,这将为你提供张量所需的所有元数据,包括你可能不会处理的细微之处(例如,别名关系)。
整体架构¶
所有假张量都与一个FakeTensorMode相关联。因为假张量的主要用例是对真实张量进行分析,所以一般的流程是,你有一组真实张量,你分配一个FakeTensorMode,然后你使用from_real_tensor将所有这些真实张量转换为假张量,然后你对这些假张量进行操作。特别是,FakeTensorMode持久地维护一个备忘录表,将张量(和存储)映射到相同的存储。如果你多次假化同一个张量,你将得到相同的假张量;如果你假化两个相互别名的张量,你将得到两个别名相同假存储的假张量。FakeTensors是张量子类,所以如果你对它们进行操作,你会自动得到一个假张量,但通常你会希望在FakeTensorMode激活的情况下对假张量进行操作(例如,如果你正在运行一个FX pass);张量操作会自动打开假张量模式并再次尝试。
一个假的张量被表示为一个元张量的__torch_dispatch__张量子类。这意味着在底层,假张量是元设备张量;它们然后使用额外的扩展性钩子,特别是dispatch_device,来谎报张量的实际设备是什么。这是早期假张量中更容易出错的部分之一:有时,假张量在谎报自己是CPU/CUDA等方面做得太好了,你会最终得到一个CPU内核被调用,而假张量试图解引用数据指针,这显然不会起作用。如果你在假张量代码中遇到段错误,这是你应该首先检查的事情:C++回溯是在CPU内核(意外!)还是元内核(预期!)中。元内核就像一个真正的内核,但它所做的只是分配输出,它不进行任何数据计算。
张量子类必须定义如何实现各种操作。以下是通用伪张量配方:
在输入的假张量上运行元内核,将它们重新解释为元张量。这是通过一个魔法上下文管理器 in_kernel_invocation_manager 完成的,它指示所有 PyTorch 将假张量视为其基础元张量,而不是将假张量“解包”为元张量(假张量是元张量)。假张量以这种方式表示,以避免必须保持两组元数据同步(元张量的元数据和假张量的元数据);“是一个”关系确保只有一份规范的元数据副本。
如果你是一个工厂函数,你将使用 device='meta' 调用底层的工厂函数。
将生成的元张量转换为假张量,计算张量应输出的设备(这通常很简单,但有时并非如此,例如,CPU标量提升,或设备转换操作。)
API: 重要部分¶
非PT2用法(更多示例请查看 test/test_fake_tensor.py):
# 创建一个假的模式
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
# 伪造一些真实的张量
fake_x = fake_mode.from_real_tensor(x)
with fake_mode:
# 对假张量进行一些操作
fake_y = fake_x * 2
# 工厂操作在上下文管理器中自动被伪造
fake_z = torch.empty(20)
问:为什么你的输入是真实张量?
A: 在PT2上下文中,这是因为您通常是即时编译的,因此对于您正在编译的图形的所有输入,您已经拥有了“真实”的输入,因为您在执行程序时进行编译。
PT2 预先 AOTAutograd 使用(这是不寻常的,你可能不想这样做):
# 假模式未启用!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
fake_args = [fake_mode.from_real_tensor(arg) for arg in args]
with fake_mode:
... do stuff with the fake args, if needed ...
detect_fake_mode 将会搜索多个位置以尝试找到与生命周期相关的“假”张量模式。通常它会从跟踪上下文中提取出来。
PT2 后AOTAutograd 使用:
# 假模式已启用!example_inputs 通常已经是假的 # TODO: 我们可能需要更改这个 # 仍然这样做以访问假模式 fake_mode = detect_fake_mode(example_inputs) # 但通常情况下你不需要打开它
其他有用的内容:
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
with maybe_disable_fake_tensor_mode():
# 假模式在这里被禁用,你可以进行真实的张量计算
你什么时候可能想要禁用假张量模式?通常你不会想要这样做。我们发现它有用的一个小众情况是实现假张量上的常量传播:在这种情况下,即使我们在假张量模式下,我们也需要进行一些实际的张量计算。
FakeTensorProp
from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# 这将在所有FX节点上填充meta['val'],并使用一个假张量
# 或者如果你有一个预先存在的假模式,你应该使用它
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# 如果你已经有假输入,也可以使用propagate_dont_convert_inputs
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
详情¶
是否自动转换? 最初,如果您尝试在 FakeTensorMode 区域内对真实张量进行计算,FakeTensorMode 不会自动伪造这些真实张量。这样做的动机是为了防止以下陷阱:
with FakeTensorMode():
real_tensor.t_()
这段代码应该做什么?如果我们实际上修改了真实张量的元数据,那将会令人惊讶。但与此同时,没有任何明显的机会来创建一个FakeTensor。因此,我们保守地决定让这段代码引发一个错误:“在FakeTensorMode中使用非Fake Tensor输入调用操作符尚不支持。请先将所有张量转换为FakeTensors。”
这个错误在实践中相当烦人。例如,假设你有一个真实的 nn.Module,并且你想通过它传递假张量。你需要以某种方式伪造 nn.Module。这促使了 FakeCopyMode 的产生。
最终,我们放弃了并添加了自动伪造功能。然而,这在许多使用FakeTensorMode的情况下仍然默认未启用。
假张量上的元数据突变 如果你有一个假张量,并且对其调用 t_() 方法,假张量上的元数据会发生变化。这在表面上看起来是合理的,但有时你可能希望将假张量也作为 FX 节点上的元数据存储;突变假张量是不好的,因为这会使旧的元数据失效!
事实上,这里存在一个根本的矛盾,即假张量保持了关于张量的极其精确的元数据,甚至包括对象身份。如果在FX图中的对象元数据随时间变化,实际上并没有任何方法来表示这种随时间的变化。大多数情况下,我们进行的重要FX分析都是在功能化的图上进行的,这些图没有这个问题,但偶尔你需要对非功能化的图进行分析。也许将假张量放在meta['val']中是一个错误。
关于张量子类¶
Fake tensor 使用了子类和模式张量子类模式,其中 FakeTensor.__torch_dispatch__ 启用了与 fake tensor 关联的 FakeTensorMode,然后重新分派(依赖 FakeTensorMode 来完成繁重的工作)。如果 fake tensor 操作接收到一个它无法识别的子类参数,它将返回 NotImplemented,给其他子类一个机会先运行(希望将其简化为普通张量操作),然后再尝试一次。这可能会导致无限循环。
每个操作符是如何实现的?¶
不幸的是,任何给定的操作符可能在相当复杂的一系列地方被实现。一些需要了解的重要情况包括:
如果元素数量非常少,张量子类支持有限的常量传播(这有助于处理一些我们立即在这些张量上调用 item() 的情况。)
我们为某些运算符提供了快速路径实现,这些实现完全在假张量中完成,出于性能原因。
如果你使用 @custom_op 生成一个自定义张量,这些将直接将 impl_abstract 注册到假张量。
Fake tensor 本身对设备转换操作有一些硬编码的特殊情况。
如果没有元实现也没有任何分解,我们将生成真实的全零填充张量,并尝试直接运行该操作符以找出结果会是什么。如果操作符尝试使用数据进行索引,这可能会导致段错误,因此我们默认不会为自定义操作开启此功能。
转换器是如何工作的?¶
因为假张量在非常敏感于张量确切属性的情况下使用,假张量在进行转换时非常小心,保留了叶性、requires_grad性、别名以及其他许多属性。大部分繁重的工作都在MetaConverter中完成。
性能特征¶
你可能会认为假张量很快,因为它们不进行任何张量计算。但在小张量尺寸下,我们实际上完全受开销限制,而且,假张量是用Python实现的,我们通常需要做很多工作来完成一个张量操作(因为它们是作为分解实现的)。因此,假张量在实践中实际上相当慢,尤其是在涉及符号形状时。目前,我们在假张量中有两个重要的快速路径,这在实践中产生了很大的差异:
逐点操作不会通过PrimTorch分解,而是我们手动编码了它们的传播规则。
如果可能的话,我们应该这样做。
假的张量的假张量?¶
目前有兴趣将伪造的张量作为用户输入发送到PT2堆栈中,这意味着我们需要能够创建一个伪造张量的伪造张量。目前这并不真正支持,但也许这并不会太难实现。
与动态形状的交互¶
每个FakeTensorMode都包含一个ShapeEnv,用于跟踪所有符号形状信息。它们的生存期通常是绑定的:它们一起生存和消亡。
因为 FakeTensorMode 有一个 ShapeEnv(但元实现没有),所以依赖数据的元函数需要分配一个无支持的 SymInt,这些函数位于 fake tensor 中。Fake tensor 还负责记忆无支持的 SymInt,因此,例如,如果你对同一个 fake tensor 调用两次 nonzero(),你会得到相同的有符号大小。