调试Triton

本教程为调试Triton程序提供指导。 主要面向Triton用户进行文档记录。 对探索Triton后端(包括MLIR代码转换和LLVM代码生成)感兴趣的开发者, 可以参考这个章节来了解调试选项。

使用Triton的调试操作

Triton包含四个调试运算符,允许用户检查和查看张量值:

  • static_printstatic_assert 用于编译时调试。

  • device_printdevice_assert 用于运行时调试。

device_assert 仅在 TRITON_DEBUG 设置为 1 时执行。 其他调试运算符的执行不受 TRITON_DEBUG 值的影响。

使用解释器

解释器是一个简单实用的工具,用于调试Triton程序。 它允许Triton用户在CPU上运行Triton程序,并检查每个操作的中间结果。 要启用解释器模式,请将环境变量TRITON_INTERPRET设置为1。 此设置会使所有Triton内核绕过编译阶段,转而由解释器使用Triton操作的numpy等效实现进行模拟。 解释器会顺序处理每个Triton程序实例,逐个执行操作。

使用解释器主要有三种方式:

  • 使用Python的print函数打印每个操作的中间结果。要检查整个张量,请使用print(tensor)。要检查idx处的单个张量值,请使用print(tensor.handle.data[idx])

  • 附加pdb用于逐步调试Triton程序:

    TRITON_INTERPRET=1 pdb main.py
    b main.py: number>
    r
    
  • 导入 pdb 包并在 Triton 程序中设置断点:

    import triton
    import triton.language as tl
    import pdb
    
    @triton.jit
    def kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
      pdb.set_trace()
      offs = tl.arange(0, BLOCK_SIZE)
      x = tl.load(x_ptr + offs)
      tl.store(y_ptr + offs, x)
    

限制

解释器存在几个已知的限制:

  • 它不支持对bfloat16数值类型的操作。要在bfloat16张量上执行操作,请使用tl.cast(tensor)将张量转换为float32

  • 它不支持间接内存访问模式,例如:

    ptr = tl.load(ptr)
    x = tl.load(ptr)
    

使用第三方工具

要在NVIDIA GPU上进行调试,compute-sanitizer是检查数据竞争和内存访问问题的有效工具。 使用时,只需在运行Triton程序的命令前添加compute-sanitizer即可。

要在AMD GPU上进行调试,您可以尝试使用LLVM AddressSanitizer for ROCm。

要详细可视化Triton程序中的内存访问,可以考虑使用triton-viz工具,该工具与底层GPU无关。