PyTorch 2.0 NNModule 支持¶
作者: Will Constable
torch.compile 对 torch.nn.Module 对象有特殊的处理方式,与它对任意 Python 类的追踪方式不同,目的是通过假设结构来生成更快的代码。
本文档描述了由于这种专业化而出现的一些权衡或边缘情况。
NNModule 钩子支持¶
之前,torch.compile 不支持 nn.Modules 上的钩子,如果在 nn.Modules 上注册了钩子,它们在编译后的程序中会被简单地忽略。确实,许多用户根本不使用 nn.Module 钩子,或者仅将它们用于调试工作流程,但将 nn.Module 钩子与 torch.compile 结合使用有其合理的用例。
通过nn.Module.__call__实现编排的钩子包括_forward_pre_hooks、forward_hooks、_backward_pre_hooks和_backward_hooks,这些钩子将被称为“调用钩子”。这些钩子在torch.compile中部分支持,但有以下限制。
另一类钩子包括 _state_dict_hooks 及其 pre 和 load_ 变体,目前仍不受 torch.compile 支持。
nn.Module.__call__ 钩子使用和限制¶
默认情况下,torch.compile 会追踪 nn.Module.__call__ 的内容,这意味着它会遇到并运行 forward/pre-forward hooks。如果你在调用 torch.compile 之前安装了 hooks,并且之后没有移除或修改这些 hooks,你的用例应该默认支持。
反向/预反向钩子通常也受支持,但有一些类似的注意事项:目前,在访问backward_hooks字典时,dynamo会发生图中断,这可能通过一些工作来避免。图中断还会影响反向钩子的触发时机,因为图段作为自动求导函数运行,这些函数会在同一时间生成所有梯度。假设dynamo能够在存在反向钩子的情况下不发生图中断,我们仍然期望一系列模块的反向钩子在编译图的反向运行完毕后一起触发。
关于“允许的模块”的钩子 torch.compile 对常见的模块(如 torch.conv)以及难以追踪的模块进行了特殊处理,允许它们在 dynamo 图中以不透明的方式调用,而不是由 dynamo 进行追踪。对于这些模块,钩子目前会触发图中断,使得受影响的模块在 dynamo 外部运行。根据模型的不同,这可能会导致显著的性能下降,因此需要额外的工作来改进这种支持。
skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards 设置为 True,这意味着不会在每个 nn.Module 钩子字典上安装守卫,从而通过减少守卫执行时间来提高运行时性能,但代价是无法注意到编译后是否有任何钩子字典被更改。
如果你希望在编译后能够移除或修改钩子,并且让torch.compile能够适当地响应(通过重新编译),那么你需要设置skip_nnmodule_hook_guards=False,并预期由于添加的防护措施而导致的运行时开销。
待办事项:确认是否正向/预正向钩子是否正常工作并相应地记录