• Tutorials >
  • Using User-Defined Triton Kernels with torch.compile
Shortcuts

使用用户定义的Triton内核与torch.compile

创建于:2024年4月19日 | 最后更新:2024年10月16日 | 最后验证:2024年11月5日

作者: Oguz Ulgen

用户定义的Triton内核可用于优化模型计算的特定部分。这些内核是用Triton语言编写的,该语言旨在更容易实现硬件性能的峰值。通过使用用户定义的Triton内核与torch.compile,您可以将这些优化计算集成到您的PyTorch模型中,可能会显著提高性能。

本食谱演示了如何使用用户定义的Triton内核与torch.compile

先决条件

在开始这个配方之前,请确保您具备以下条件:

import torch
from torch.utils._triton import has_triton

基本用法

在这个例子中,我们将使用Triton文档中的一个简单的向量加法内核 与torch.compile。 作为参考,请参阅Triton文档

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([ 0.1940,  2.1614, -0.1721,  0.8491], device='cuda:0')
Y:      tensor([ 0.1391, -0.1082, -0.7174,  0.7566], device='cuda:0')
is equal to
tensor([ 0.3332,  2.0532, -0.8895,  1.6057], device='cuda:0')

高级用法

Triton的自动调优功能是一个强大的工具,它能自动优化Triton内核的配置参数。它会探索一系列可能的配置,并选择为您的特定用例提供最佳性能的配置。

当与torch.compile一起使用时,triton.autotune可以帮助确保您的PyTorch模型尽可能高效地运行。这里是一个使用torch.compiletriton.autotune的示例。

注意

torch.compile 仅支持配置和关键参数到 triton.autotune

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([-0.5187,  1.2268,  0.6255, -0.9117], device='cuda:0')
Y:      tensor([-0.6974, -1.8688, -0.8832, -1.6627], device='cuda:0')
is equal to
tensor([-1.2161, -0.6421, -0.2577, -2.5744], device='cuda:0')

可组合性与限制

截至 PyTorch 2.3,torch.compile 中对用户自定义 Triton 内核的支持包括动态形状、torch.autograd.Function、JIT 诱导器和 AOT 诱导器。您可以结合使用这些功能来构建复杂的高性能模型。

然而,有一些需要注意的限制:

  • 张量子类: 目前,不支持张量子类和其他高级功能。

  • Triton 特性: 虽然 triton.heuristics 可以单独使用或在 triton.autotune 之前使用,但不能在 triton.autotune 之后使用。这意味着如果 triton.heuristicstriton.autotune 要一起使用,必须首先使用 triton.heuristics

结论

在本教程中,我们探讨了如何利用用户定义的Triton内核与torch.compile。我们深入研究了简单向量加法内核的基本用法以及涉及Triton自动调优功能的高级用法。我们还讨论了用户定义的Triton内核与其他PyTorch功能的可组合性,并强调了一些当前的限制。

优云智算