Shortcuts

torch.compile

torch.compile(model=None, *, fullgraph=False, dynamic=None, backend='inductor', mode=None, options=None, disable=False)[源代码]

使用TorchDynamo和指定的后端优化给定的模型/函数。

具体来说,对于在编译区域内执行的每一帧,我们将尝试编译它并将编译结果缓存到代码对象中以供将来使用。如果先前的编译结果不适用于后续调用(这称为“保护失败”),您可以使用 TORCH_LOGS=guards 来调试这些情况。多个编译结果可以与一帧关联,最多可达 torch._dynamo.config.cache_size_limit,默认值为 64;在此之后,我们将回退到急切模式。请注意,编译缓存是按 代码对象 进行的,而不是按帧进行的;如果您动态创建函数的多个副本,它们将共享相同的代码缓存。

Parameters
  • 模型 (可调用对象) – 要优化的模块/函数

  • fullgraph (bool) – 如果为False(默认),torch.compile尝试在函数中发现可编译的区域进行优化。如果为True,则要求整个函数能够捕获到一个单一的图中。如果无法实现(即存在图中断),则会引发错误。

  • 动态 (布尔值) – 使用动态形状追踪。当此值为True时,我们将尝试预先生成尽可能动态的内核,以避免在尺寸变化时重新编译。这可能并不总是有效,因为某些操作/优化将强制进行专门化;使用TORCH_LOGS=dynamic来调试过度专门化。当此值为False时,我们将永远不会生成动态内核,我们总是会进行专门化。默认情况下(无),我们会自动检测是否发生了动态变化,并在重新编译时编译一个更动态的内核。

  • backend (strCallable) –

    要使用的后端

    • “inductor” 是默认后端,它在性能和开销之间提供了良好的平衡

    • 可以通过 torch._dynamo.list_backends() 查看非实验性的内置后端

    • 可以通过 torch._dynamo.list_backends(None) 查看实验性或调试性的内置后端

    • 要注册一个自定义的外部后端:https://pytorch.org/docs/main/compile/custom-backends.html

  • 模式 (字符串) –

    可以是“默认”、“减少开销”、“最大自动调优”或“最大自动调优无CUDA图”

    • “默认”是默认模式,它在性能和开销之间提供了良好的平衡

    • “减少开销”是一种减少Python与CUDA图开销的模式, 适用于小批量数据。减少开销可能会以更多内存使用为代价, 因为我们将会缓存调用所需的临时内存,以便在后续运行中不需要重新分配。 减少开销并不保证一定有效;目前,我们仅减少不修改输入的CUDA图的开销。 在其他情况下,CUDA图可能不适用;使用TORCH_LOG=perf_hints进行调试。

    • “最大自动调优”是一种利用基于Triton的矩阵乘法和卷积的模式 它默认启用CUDA图。

    • “最大自动调优无CUDA图”是一种类似于“最大自动调优”但不含CUDA图的模式

    • 要查看每个模式设置的确切配置,可以调用torch._inductor.list_mode_options()

  • options (dict) –

    传递给后端的选项字典。一些值得尝试的选项包括

    • epilogue_fusion 它将点操作融合到模板中。需要同时设置 max_autotune

    • max_autotune 它将进行分析以选择最佳的矩阵乘法配置

    • fallback_random 在调试精度问题时非常有用

    • shape_padding 它将矩阵形状填充以更好地对齐GPU上的负载,特别是对于张量核心

    • triton.cudagraphs 它将减少使用CUDA图时的Python开销

    • trace.enabled 这是最有用的调试标志,可以打开

    • trace.graph_diagram 它将在融合后显示您的图形的图片

    • 对于inductor,您可以通过调用 torch._inductor.list_options() 查看它支持的完整配置列表

  • 禁用 (布尔值) – 将 torch.compile() 变为无操作以进行测试

Return type

可调用

示例:

@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def foo(x):
    return torch.sin(x) + torch.cos(x)
优云智算