• Tutorials >
  • Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile()
Shortcuts

通过使用嵌套张量和torch.compile()替换nn.Transformer来加速PyTorch Transformers

作者: Mikayla Gawarecki

注意

本教程目前要求您使用PyTorch的夜间构建版本。

What you will learn
  • 了解PyTorch提供的用于构建自定义Transformer层的底层构建块(嵌套张量、scaled_dot_product_attentiontorch.compile()FlexAttention

  • 了解以上内容如何以MultiHeadAttention为例改进内存使用和性能

  • 使用上述构建块探索高级自定义

Prerequisites
  • PyTorch v.2.6.0 或更高版本

在过去的几年里,PyTorch 团队开发了各种低级功能,这些功能组合起来可以创建多种 transformer 变体。这些包括:

  • 使用torch.jagged布局的嵌套张量(也称为NJTs)

  • scaled_dot_product_attention

  • torch.compile()

  • FlexAttention

本教程将简要概述上述技术,并展示如何将它们组合起来,以产生灵活且高性能的变压器层,从而改善用户体验。

可以观察到,torch.nn 模块目前提供了各种与 Transformer 相关的层。 特别是,它包括 TransformerEncoderLayerTransformerEncoderTransformerDecoderLayerTransformerDecoderTransformerMultiheadAttention。这一系列的层最初是根据 Attention is All You Need 论文实现的。本教程中讨论的组件在用户体验、灵活性和性能方面提供了比现有 nn 层更好的改进。

这个教程适合我吗?

如果你想知道torch库提供了哪些构建模块来编写你自己的transformer层以及最佳实践,那么你来对地方了。请继续阅读!

如果你正在寻找一个流行的transformer架构的开箱即用实现,请注意有许多开源库提供了这些实现,包括:

如果您只对高效的注意力分数修改感兴趣,请查看包含FlexAttention博客gym of masks

介绍构建模块

首先,我们将简要介绍引言中提到的四种技术

嵌套张量扩展了常规密集张量的形状,允许使用相同的张量用户体验来表示不规则大小的数据。在变压器的上下文中,我们可以将嵌套张量视为表示可变序列长度的工具。它们消除了容易出错的显式填充和掩码操作的需要(想想key_padding_masknn.MultiHeadAttention中的使用)。

scaled_dot_product_attention 是一个用于 \(\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V\) 的原语,它会分派到该操作符的融合实现或回退实现。它在即时模式下开箱即用(即 PyTorch 的默认模式,操作在遇到时立即执行),并且与 torch.compile() 无缝集成。从 2.6 版本开始,它还将原生支持分组查询注意力。

torch.compile() 是在2.0版本中引入的编译器,能够捕获PyTorch代码的图并对其进行各种优化,例如将操作序列融合在一起。使用torch.jagged布局的嵌套张量和scaled_dot_product_attention与编译无缝协作。在transformers的上下文中,使用编译与嵌套张量和SDPA的附加价值在于,编译可以消除在eager模式下看到的框架开销,并将transformers中的操作序列(如投影和激活)融合在一起。

FlexAttention 是一个允许用户在 softmax 操作之前修改注意力分数的原语。它概括了上述 scaled_dot_product_attention 的加法 B 项,允许进行任意计算。它需要编译以实现良好的性能。

上述构建模块是“您所需要的全部”(截至2024年10月)

本节的主要前提是,大多数变换器的变体都是GPT风格的,由嵌入层、位置编码、注意力块和前馈网络等层组成。如果我们试图对这一领域的差异进行分类,可能会得出类似以下的内容:

  1. 层类型(激活函数如SwiGLU和其他,归一化函数如RMSNorm和其他,位置编码,如正弦、旋转等。)

  2. 层排序,例如在哪里应用规范和位置编码。

  3. 对注意力分数的修改,例如ALiBi、相对位置偏差等。

在预编译环境中,您可能会编写一个自定义转换器并注意到它功能正确但速度较慢。为了解决这个问题,您可能会为特定的操作系列开发一个自定义融合内核。在编译环境中,您可以简单地执行初始步骤,然后编译并从改进的性能中受益。

多头注意力机制

请记住,MultiheadAttention 接收查询、键和值,并包含一个输入投影、一个 scaled_dot_product_attention 操作符和一个输出投影。我们在这里想要展示的主要收获是,当我们用嵌套张量替换填充/掩码输入时所带来的改进。改进有三方面:

  • 用户体验 请记住,nn.MultiheadAttention 需要 querykeyvalue 是密集的 torch.Tensors。它还提供了一个 key_padding_mask,用于屏蔽由于批次中不同序列长度而产生的 key 中的填充标记。由于 nn.MHA 中没有 query_padding_mask,用户必须注意适当地屏蔽/切片 输出以考虑查询序列长度。NestedTensor 完全消除了对这种容易出错的填充掩码的需求。

  • 内存 与其实例化一个密集的 [B, S, D] 张量并附带一个 [B, S] 填充掩码(其中 B 是批量大小,S 是批次中的最大序列长度, D 是嵌入大小),嵌套张量允许您清晰地表示不同序列长度的批次。因此,输入和 中间激活将使用更少的内存。

  • 性能 由于填充不会被具体化,并且跳过了对填充的不必要计算,因此性能和内存使用情况得到了改善。

我们将通过在Nested Tensor教程中的MultiheadAttention层的基础上进行演示,并将其与nn.MultiheadAttention层进行比较。

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
        bias (bool, optional): Whether to add bias to input projection. Default: True
    """
    def __init__(
        self,
        E_q: int,
        E_k: int,
        E_v: int,
        E_total: int,
        nheads: int,
        dropout: float = 0.0,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.nheads = nheads
        self.dropout = dropout
        self._qkv_same_embed_dim = E_q == E_k and E_q == E_v
        if self._qkv_same_embed_dim:
          self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
        else:
          self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
          self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
          self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
        self.bias = bias

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                attn_mask=None,
                is_causal=False) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
            key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
            value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
            attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
            is_causal (bool, optional): Whether to apply causal mask. Default: False

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        if self._qkv_same_embed_dim:
            if query is key and key is value:
                result = self.packed_proj(query)
                query, key, value = torch.chunk(result, 3, dim=-1)
            else:
                q_weight, k_weight, v_weight = torch.chunk(self.packed_proj.weight, 3, dim=0)
                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(self.packed_proj.bias, 3, dim=0)
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = F.linear(query, q_weight, q_bias), F.linear(key, k_weight, k_bias), F.linear(value, v_weight, v_bias)

        else:
            query = self.q_proj(query)
            key = self.k_proj(key)
            value = self.v_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=self.dropout, is_causal=is_causal)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

实用工具

在本节中,我们包含了一个实用工具,用于使用Zipf分布生成半真实的数据以模拟句子长度。这用于生成嵌套的查询、键和值张量。我们还包含了一个基准测试工具。

import numpy as np

def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)

# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
# in the form of nested tensors with the jagged layout.
def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    if query_seq_len_1:
        query = torch.nested.nested_tensor([
            torch.randn(1, E_q, dtype=dtype, device=device)
            for l in sentence_lengths
        ], layout=torch.jagged)
    else:
        query = torch.nested.nested_tensor([
            torch.randn(l.item(), E_q, dtype=dtype, device=device)
            for l in sentence_lengths
        ], layout=torch.jagged)

    key = torch.nested.nested_tensor([
        torch.randn(s.item(), E_k, dtype=dtype, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    value = torch.nested.nested_tensor([
        torch.randn(s.item(), E_v, dtype=dtype, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    return query, key, value, sentence_lengths

import timeit
import math

def benchmark(func, *args, **kwargs):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin), torch.cuda.max_memory_allocated()

我们现在将展示在MultiheadAttention层中使用嵌套张量并结合自注意力机制的性能改进。我们将此与传统的nn.MultiheadAttention结合填充和掩码进行比较。

N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512
E_out = E_q
d_model = E_q
nheads = 8
dropout = 0.0
bias = True
device='cuda'
torch.manual_seed(6)
query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)
S = sentence_lengths.max().item()
print(f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}")
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

torch.manual_seed(6)
mha_layer = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device='cuda')
torch.manual_seed(6)
vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda')

# ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :(
mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach())
mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach())
mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
mha_layer.packed_proj.bias = nn.Parameter(vanilla_mha_layer.in_proj_bias.clone().detach())

new_mha_layer = torch.compile(mha_layer)
# warmup compile
nested_result_warmup = new_mha_layer(query, query, query, is_causal=True)

# benchmark
nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, query, query, is_causal=True)
padded_nested_result = nested_result.to_padded_tensor(0.0)

# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
attn_mask = torch.empty((N, S, S), device=device).fill_(float('-inf'))
for i, s in enumerate(sentence_lengths):
    attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s)
attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N*nheads, S, S)

vanilla_mha_layer = torch.compile(vanilla_mha_layer)
# warmup compile
warmup_vanilla_result = vanilla_mha_layer(padded_query,
                                          padded_query,
                                          padded_query,
                                          attn_mask=attn_mask,
                                          key_padding_mask=src_key_padding_mask,
                                          need_weights=False,
                                          is_causal=True)

# benchmark
(padded_result, _), padded_time, padded_peak_memory = benchmark(vanilla_mha_layer,
                                                                padded_query,
                                                                padded_query,
                                                                padded_query,
                                                                key_padding_mask=src_key_padding_mask,
                                                                need_weights=False,
                                                                attn_mask=attn_mask,
                                                                is_causal=True)

print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")

作为参考,以下是一些在A100上的示例输出:

padded_time=0.03454, padded_peak_memory=4.14 GB
nested_time=0.00612, nested_peak_memory=0.76 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 5.65
Nested peak memory reduction 3.39 GB

我们也可以看到反向传播的相同情况

for i, entry_length in enumerate(sentence_lengths):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

_, padded_bw_time, padded_bw_peak_mem = benchmark(lambda : padded_result.sum().backward())
_, nested_bw_time, nested_bw_peak_mem = benchmark(lambda : padded_nested_result.sum().backward())

print(f"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB")
print(f"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB")
print(f"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}")
print(f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB")

print("Difference in out_proj.weight.grad", (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad).abs().max().item())
print("Difference in packed_proj.weight.grad", (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad).abs().max().item())
print("Difference in out_proj.bias.grad", (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad).abs().max().item())
print("Difference in packed_proj.bias.grad", (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad).abs().max().item())

A100上的示例输出:

padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
Nested backward speedup: 144.13
Nested backward peak memory reduction 1.86 GB
Difference in out_proj.weight.grad 0.000244140625
Difference in packed_proj.weight.grad 0.001556396484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.001953125

GPT风格层

一个基本的GPT风格变换器层由一个因果自注意力层和一个带有跳跃连接的前馈网络(FFN)组成。使用上面的MultiheadAttention层实现这一点相当直接,并且与带有is_causal=Truenn.TransformerEncoderLayer给出等效的结果。

我们展示了实现其余nn层的示例 这里 但为了简洁起见,本教程中省略了这些内容。

更进一步

到目前为止,我们已经展示了如何实现一个遵循传统nn.MultiheadAttention的高性能MultiheadAttention层。回到我们对Transformer架构修改的分类,记得我们将修改分为层类型、层顺序和对注意力分数的修改。我们相信改变层类型和层顺序(例如将LayerNorm替换为RMSNorm)是相当直接的。

在本节中,我们将讨论使用上述构建块的各种功能,包括以下内容:

  • 交叉注意力

  • 完全屏蔽的行不再导致NaN

  • 修改注意力分数:使用FlexAttention和NJT的ALiBi

  • 打包投影

交叉注意力

交叉注意力是一种注意力形式,其中查询和键/值张量来自不同的序列。

一个例子是在nn.TransformerDecoderLayer中,查询来自解码器,而键/值来自编码器。

上述的MultiheadAttention层很好地推广到了这种情况,其中查询和键/值都是嵌套张量。

query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)

print(f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}")
print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
out = new_mha_layer(query, key, value, is_causal=False)

如上所述,我们可以将其与普通的编译nn.MultiheadAttention进行比较。

torch.manual_seed(6)
query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]

# warmup compile
warmup_nested_result = new_mha_layer(query, key, value, is_causal=False)
warmup_vanilla_result = vanilla_mha_layer(padded_query,
                                          padded_key,
                                          padded_value,
                                          key_padding_mask=key_padding_mask,
                                          need_weights=False,
                                          is_causal=False)

nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, key, value, is_causal=False)
(padded_result, _), padded_time, padded_peak_memory = benchmark(vanilla_mha_layer,
                                                                padded_query,
                                                                padded_key,
                                                                padded_value,
                                                                key_padding_mask=key_padding_mask,
                                                                need_weights=False,
                                                                is_causal=False)
padded_nested_result = nested_result.to_padded_tensor(0.0)
for i, entry_length in enumerate(q_len):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")

A100上的示例输出:

Max difference between vanilla and nested result 0.0
Nested speedup: 4.01
Nested peak memory reduction 1.40 GB

完全屏蔽的行不再导致NaN

长期以来,nn.MultiheadAttentionscaled_dot_product_attention 存在一个问题,即如果一行被完全屏蔽,注意力层的输出 将会是 NaN。请参阅 issue。 这是因为在空集上的 softmax 是未定义的。

感谢这个PR,这种情况不再存在。相反,完全屏蔽的行在scaled_dot_product_attention中。对于nn.MHA不使用“快速路径”的情况,这也将适用。

强烈建议使用自定义的MHA层与NJT结合,而不是使用现有的nn.MultiheadAttention中的“快速路径”,因为NJT能够适当地建模不规则性,从而能够正确表达空序列。

FlexAttention + NJT

NJT 还与 FlexAttention 模块组合使用。这是对 MultiheadAttention 层的一种泛化,允许对注意力分数进行任意修改。下面的示例使用了来自 attention gymalibi_mod,它实现了 ALiBi,并将其与嵌套的输入张量一起使用。

from torch.nn.attention.flex_attention import flex_attention

def generate_alibi_bias(H: int):
    """Returns an alibi bias score_mod given the number of heads H
    Args:
        H: number of heads
    Returns:
        alibi_bias: alibi bias score_mod
    """
    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (q_idx - kv_idx) * scale
        return score + bias
    return alibi_mod

query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
n_heads, D = 8, E_q // 8
alibi_score_mod = generate_alibi_bias(n_heads)
query = (
    query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = (
    value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)

此外,还可以通过create_nested_block_mask函数将FlexAttentionblock_mask工具与NJT结合使用。这对于利用掩码的稀疏性来加速注意力计算非常有用。特别是,该函数为NJT中所有变长序列组合成一个“堆叠序列”创建了一个稀疏块掩码,同时正确地屏蔽了序列间的注意力。在下面的示例中,我们展示了如何使用此工具创建因果块掩码。

from torch.nn.attention.flex_attention import create_nested_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
query = (
    query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = (
    value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
out_flex = flex_attention(query, key, value, block_mask=block_mask)

打包投影

打包投影是一种技术,它利用了当投影的输入(矩阵乘法)相同(自注意力)时,我们可以将投影权重和偏置打包成单个张量的事实。这在单个投影受内存限制而不是计算限制时特别有用。这里我们将展示两个示例:

  • MultiheadAttention的输入投影

  • Transformer层前馈网络中的SwiGLU激活

MultiheadAttention的输入投影

在进行自注意力时,querykeyvalue是相同的张量。每个张量都通过一个Linear(E_q, E_total)层进行投影。相反,我们可以将其打包到一个层中,这就是我们在上面的MultiheadAttention层中所做的。

让我们比较打包投影与常规方法的性能:

class InputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)

    def forward(self, x):
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)

class PackedInputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)

    def forward(self, query):
        return torch.chunk(self.packed_proj(query), 3, dim=-1)

B, D, dtype = 256, 8192, torch.bfloat16

torch.set_float32_matmul_precision('high')
in_proj = torch.compile(InputProjection(D, D, device='cuda', dtype=torch.bfloat16))
packed_in_proj = torch.compile(PackedInputProjection(D, D, device='cuda', dtype=torch.bfloat16))

q, _, _, sequence_lengths = gen_batch(B, D, D, D, device='cuda', dtype=torch.bfloat16)

# warmup
in_proj(q)
packed_in_proj(q)

# benchmark
(q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
# On my A100 prints 1.05x speedup
print(f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x")

Transformer层的SwiGLU前馈网络

SwiGLU(Swish-Gated Linear Unit)是一种非线性激活函数,在Transformer层的前馈网络(例如Llama)中越来越受欢迎。使用SwiGLU激活的前馈网络定义为:

class SwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

实现此目的的另一种方法是使用打包投影

class PackedSwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)

    def forward(self, x):
        x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
        return self.w2(F.silu(x1) * x3)

我们可以比较两种实现的性能如下 根据您的硬件,您可能会看到不同的结果。在A100上,我看到 D=128时速度提升了1.12倍。

D = 128

swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device='cuda', dtype=torch.bfloat16))
packed_swigluffn = torch.compile(PackedSwiGLUFFN(D, D * 4, 256, device='cuda', dtype=torch.bfloat16))

q, _, _, sentence_lengths = gen_batch(D, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
swigluffn(q)
packed_swigluffn(q)

# benchmark
_, time, _ = benchmark(swigluffn, q)
_, time_packed, _ = benchmark(packed_swigluffn, q)
# On my A100 prints 1.08x speedup
print(f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x")

扩展示例

我们计划更新本教程,以展示更多关于如何使用各种高性能构建模块的示例,例如KV缓存、分组查询注意力等。此外,还有一些很好的示例展示了如何使用各种高性能构建模块来实现不同的transformer架构。一些示例包括

结论

在本教程中,我们介绍了PyTorch提供的用于编写Transformer层的低级构建块,并展示了如何组合它们的示例。我们希望本教程能够向读者展示,PyTorch用户可以轻松实现灵活且高性能的Transformer层。

脚本总运行时间: ( 0 分钟 0.000 秒)

Gallery generated by Sphinx-Gallery

优云智算