多头、多查询和组查询注意力

本文详细介绍了在TensorRT-LLM中为自回归GPT类模型实现多头注意力(MHA)、多查询注意力(MQA)和组查询注意力(GQA)的方法。快速回顾一下,多头注意力是在Attention Is All You Need文章中描述的批处理矩阵乘法、softmax和另一个批处理矩阵乘法的序列。多查询注意力(MQA)组查询注意力(GQA)是MHA的变体,它们使用的K/V头数量少于查询头的数量。TensorRT-LLM、MHA、MQA和GQA由操作符tensorrt_llm.functional.gpt_attention实现。

重要提示

如下所述,当前实现支持两种输入模式:填充模式和非填充模式(打包模式)。由于打包模式总是比填充模式更节省内存且速度更快,未来可能会移除对填充模式的支持

填充和打包的张量

在TensorRT-LLM中,GPT注意力算子支持两种不同类型的QKV输入:填充和打包(即未填充)输入。模式由tensorrt_llm.plugin中定义的全局配置参数remove_input_padding决定。

当启用填充时(即remove_input_paddingFalse),比max_sequence_length短的序列将被填充到该最大长度。这可能会导致过多的内存消耗以及在填充标记上不必要的计算(在MHA块周围的各种矩阵乘法中)。

为了解决这个问题,TensorRT-LLM 支持一种无填充的模式,其中不同的标记被打包在一起,用户向操作符提供一个包含不同序列长度的1D张量。建议用户始终使用打包模式(并且未来可能会移除对填充模式的支持)。

上下文和生成阶段

GPT注意力操作符封装了自回归模型(如GPT)中上下文和生成阶段的不同实现。

上下文阶段

如果context_fmha_type设置为disabled(参考 tensorrt_llm.plugin), 实现将映射到一系列GPU内核,这些内核将在调用softmax操作符之前将中间Q*K^T张量存储在内存中。这是最慢的方法,并且内存占用显著(与序列长度的平方成正比)。

否则,如果 context_fmha_type 设置为 enabledenabled_with_fp32_acc(在第一个批量矩阵乘法中强制使用 FP32 累加),该函数将触发一个使用单个内核执行 MHA/MQA 块的内核。对于短序列,该内核使用 MHA/MQA 的普通实现。对于较长的序列,该内核使用 Flash Attention 算法,如 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-AwarenessFlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning 中所述。

目前,该实现触发了额外的内核,这些内核对元素进行预处理(如RoPE)并填充KV缓存(见下文)。在未来的版本中,计划减少此类内核的数量,以提高整体性能。

FP8 上下文 FMHA

当FP8量化被激活时,通过启用FP8上下文FMHA(use_fp8_context_fmha = enable),注意力可以进一步加速。

FP8 分页上下文 FMHA 也支持 fp8 量化工作流。 你需要同时指定 use_fp8_context_fmha = enableuse_paged_context_fmha = enable

请注意,这是一个仅在Hopper上支持的实验性功能。 如果您注意到准确率显著下降,建议禁用它。

生成阶段

生成阶段是通过在TensorRT-LLM中使用一个称为掩码多头注意力的单一内核来实现的。该内核能够即时对Q、K和V元素进行预处理:添加QKV偏置,应用RoPE,并执行去量化和量化。TensorRT-LLM将在未来的版本中继续添加(或启用)更多功能。例如,启用对IA3的支持。

掩码的MHA内核有一个特殊版本,在GPU占用率较低的情况下,将工作分配到GPU上的多个CUDA线程块上。从TRT-LLM 0.13开始,默认启用这种称为多块的模式,并且可以在运行时使用--multi_block_mode=False来禁用。建议用户在模型的批量大小和头数都相对较小的情况下测试该模式。在这种情况下,“小”的确切定义将取决于GPU的型号,并且难以预测,但作为一个经验法则,当batch_size * num_heads小于GPU上的多处理器数量时,值得测试该模式(随着更多研究的进行和软件的改进,这个建议可能会在未来发生变化)

请注意,即使启用了多块模式,注意力操作符也不会立即触发GPU内核的多块版本。需要达到一定数量的令牌(输入+生成)才能使多块版本比使用每个头一个CUDA线程块的“普通”实现更高效。这是由内部启发式控制的。

另一个需要注意的是,由于掩码MHA内核使用的共享内存大小与序列长度成正比,因此在未启用多块模式的情况下,可能会出现GPU的共享内存不足的情况。为了使掩码MHA内核在这些情况下工作,强制启用多块模式并打印警告日志。

XQA优化

另一种针对MQA/GQA在生成阶段的优化称为XQA优化。 这仍然是一个实验性功能,支持有限的配置。LLAMA2 70B 是它支持的模型之一。

XQA优化的支持矩阵:

  • FP16 / BF16 计算数据类型。

  • FP16 / BF16 / FP8 / INT8 KV缓存数据类型。

  • 分页键值缓存(每块64 / 128个令牌)。

这是默认启用的。要禁用此功能,您需要在构建引擎时使用标志--disable_xqa。请注意,还会使用启发式算法来决定是使用XQA内核还是掩码MHA内核以获得更好的性能。这意味着即使未设置--disable_xqa,也可能不会使用XQA内核。如果您希望在可能的情况下始终使用该内核,可以设置TRTLLM_FORCE_XQA=1以在支持模型配置时强制使用XQA内核。详细的受支持配置可以在cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h中的DecoderXQARunner类的shouldUse函数中找到。

飞行中的批处理

TensorRT-LLM 支持请求的飞行中批处理(也称为连续批处理或迭代级批处理),以提高服务吞吐量。通过此功能,上下文阶段的序列可以与生成阶段的序列一起处理。该技术的目的是更好地交错请求以减少延迟,并更好地利用 GPU。出于效率原因(1),飞行中批处理的支持要求输入张量被打包(无填充)

在当前实现中,输入张量中处于上下文阶段的序列必须位于生成阶段的序列之前。例如,对于序列 S0S1S2,如果 S0S2 处于上下文阶段(而 S1 处于生成阶段),则 S0S2 的标记必须在输入张量中出现在 S1 的标记之前。此约束在未来的版本中可能会或可能不会放宽。

(1) 在生成阶段,对仅包含单个标记的序列进行填充,使其达到最大输入序列的长度,是对资源的低效使用

分块上下文

在原始状态下,常见的行为是一次性处理所有上下文标记。此功能将上下文分成几个块。通过这种方式,在生成阶段可以将上下文块与更多标记一起批处理,预计将提高总吞吐量。分块上下文还消除了对输入长度的限制。要启用此功能,还需要启用FMHA分页kv-cache。除了最后一个,上下文块的大小需要是kv-cache块大小的整数倍。有关使用情况,请参考性能最佳实践

KV 缓存

在生成阶段,一个常见的优化是向MHA内核提供一个缓存,该缓存包含已经计算过的过去K和V元素的值。该缓存被称为KV缓存。TensorRT-LLM使用该技术来加速其生成阶段。在TensorRT-LLM中,每个Transformer层都有一个KV缓存,这意味着模型中有多少层就有多少KV缓存。当前版本的TensorRT-LLM支持两种不同类型的KV缓存:连续分页 KV缓存。

连续KV缓存

连续的KV缓存是一个单一的张量。它的形状是:

[max_batch_size * max_beam_width, 2, num_heads, max_seqlen, hidden_dim_per_head].

当序列比最大序列长度短时,该实现使用的内存比实际需要的多得多(即使它们在生成许多输出标记后接近限制,也可能需要很多步骤才能达到那个点)。

分页键值缓存

分页KV缓存将KV缓存分解为块,这些块在处理过程中由缓存管理器分配给不同的请求。该缓存管理器跟踪序列,从池中分配新块,并在需要时回收这些块。请参阅tensorrt_llm.runtime.KVCacheManager的简化实现。更高效的C++实现包含在Batch Manager中。

INT8/FP8 KV 缓存

在当前的实现中,即使网络的其余部分运行在INT8或FP8,GPT注意力算子仍然使用FP32、FP16和BFloat16输入和输出。然而,TensorRT-LLM支持INT8和FP8(kv_cache_quant_mode=QuantMode.INT8_KV_CACHEkv_cache_quant_mode=QuantMode.FP8_KV_CACHE)KV缓存。

GPT注意力操作符填充KV缓存。当启用INT8或FP8 KV缓存时,输入值必须使用缩放因子量化为8位。对于量化,缩放因子存储在kv_cache_scaling_factor张量中。其形状为[1],当前版本仅支持每张量量化。量化使用逆缩放,因为它在插件中执行fp_value * (1.0 / kv_cache_scaling_factor)的乘法。

在生成过程中,从缓存中读取的值会在MHA/MQA内核中实时反量化,反量化可以描述为 quantized_value * kv_cache_scaling_factor

滑动窗口注意力,循环(滚动缓冲区)KV缓存

TensorRT-LLM 有一个名为 Cyclic KV Cache 的功能,它将 kv 缓存视为一个循环缓冲区。这意味着它只存储最后 N 个 token 的 kv 缓存,其中 N 由 GenerationSession.setup 中的 max_attention_window_size 参数决定。你可以在 run.pysummarize.py 文件中看到这方面的示例。当缓存满时,新 token 的 kv 缓存将覆盖“最近最少使用”的缓存。

在上下文阶段,如果输入长度超过max_attention_window_sizeSliding Window Attention将被激活。这与 sliding window_size的功能相同。

此功能有助于在处理非常长的序列时减少kv缓存的内存占用。

该功能允许每层使用不同的max_attention_window_size值,也支持此功能。要利用此功能,只需在使用python运行时会话时,向GenerationSession.setup提供一个int32 torch.Tensorlist,或者在使用cpp运行时向KvCacheConfig提供一个向量。如果提供的元素数量少于层数,则提供的张量/列表/向量将被重复多次以达到层数,然后保存为一个新的张量。此张量将作为max_attention_window_size的缓冲区,为每层设置唯一值。然而,需要注意的是,kv缓存的内存分配仍然依赖于缓冲区的最大值。

_请注意,循环kv缓存功能目前不适用于光束搜索,因为上下文kv缓存在光束之间共享。

流式LLM

StreamingLLM 功能使用窗口注意力机制在长文本上执行高效且稳定的 LLM,这意味着只需要在 KV 缓存中存储 N 个令牌。与 TensorRT-LLM 中的循环 KV 缓存功能类似,max_attention_window_size 参数用于确定 N。与循环 KV 缓存功能不同的是,前 S 个令牌,称为 sink tokens,始终保留在注意力窗口中,其中 SGenerationSession.setup 中的 sink_token_length 参数确定。但在上下文阶段,StreamingLLM 的官方实现中自注意力是密集的,它使用所有令牌进行计算,并且只将 N 个令牌保存到 KV 缓存中。

此外,StreamingLLM中的相对位置嵌入也发生了变化。 在确定相对距离并向标记添加位置信息时, StreamingLLM使用缓存中的位置,而不是原始文本中的位置。

streamingllm 标志用于启用此功能。

输入QKV张量

输入的QKV张量在隐藏状态的投影后打包了Q、K和V张量(沿最后一个维度连接)。它是一个3D张量。RoPE和量化为INT8或FP8(在需要时)由GPT注意力操作符执行。

在填充模式下,其形状为[batch_beam_size, max_seqlen, 3 * hidden_dim] 其中batch_beam_size是上下文阶段的批次大小(序列数量),在生成阶段是批次大小乘以束宽。 在填充模式下不支持每个序列具有不同的束宽。

在打包模式下,其形状为[num_tokens, 3 * hidden_dim],其中 num_tokens是批次中令牌的总数。对于上下文阶段的序列,一个序列的令牌数对应于其输入 长度(即使对于波束搜索,波束宽度大于1)。对于 生成阶段的序列,每个序列有beam_width个令牌。每个序列的 波束宽度可以不同。

换句话说,计算令牌数量的伪代码如下:

num_tokens = 0

# Add the length of each sequence in context phase.
for seq in context_phase:
    num_tokens += seq.length

# Add the width of the beam for each sequence in generation phase.
for seq in generation_phase:
    num_tokens += seq.beam_width

旋转位置嵌入 (RoPE)

GPT注意力操作可以执行旋转位置嵌入(RoPE)的计算。当启用该操作时,rotary_embedding_dim被设置为大于0的值,并与其他操作融合。GPT操作符通过将position_embedding_type设置为PositionEmbeddingType.rope_gpt_neoxPositionEmbeddingType.rope_gptj来支持GPT-NeoX和GPT-J形式的RoPE。

ALiBi

GPT注意力算子可以将ALiBi应用于Q*K^T乘积的结果。偏差是从优化内核中的ALiBi斜率动态计算的。

缩放因子

在MHA中,Q*K^T 乘积的输出通过一个常数进行缩放,该常数的计算方式如下:

norm_factor = 1.f / (q_scaling * sqrt(head_size)).

交叉注意力

在GPT风格的仅解码器模型所需的自注意力机制MHA之上,gpt_attention 还支持交叉注意力。

这使得gpt_attention能够在更广泛的方面作为通用解码器组件使用。例如,编码器-解码器模型使用gpt_attention在其解码器中同时发出自注意力和交叉注意力模块。

相对注意力偏差 (RAB)

相对注意力偏差(RAB)是一种相对位置建模方法,根据相对位置添加一个注意力偏差(Q*K^T+bias)。RAB 是一种轻量级的方法,用于包含相对位置的信息,并在流行的编码器-解码器模型 T5 以及 T5 系列中的其他模型中使用。

RAB 支持两种模式:i) 常规模式,用户传入在 MHA 之前计算好的相对注意力偏差。ii) 隐式模式,在 MHA 中动态计算相对注意力偏差。隐式模式适用于相对注意力偏差太大无法放入内存的情况,可以通过传入 max_distance 来开启。