分页注意力机制¶
警告
本文档是基于vLLM原始论文的历史文献,已不再描述当前vLLM中使用的代码。
目前,vLLM使用其自研的多头查询注意力内核(csrc/attention/attention_kernels.cu)。该内核设计兼容vLLM的分页KV缓存机制,其中键值缓存存储在独立块中(请注意此处的块概念与GPU线程块不同。因此在后续文档中,我将vLLM分页注意力块称为"block",而将GPU线程块称为"thread block")。
为了实现高性能,该内核依赖于一种特殊设计的内存布局和访问方法,特别是当线程从全局内存读取数据到共享内存时。本文档旨在逐步提供对内核实现的高层次解释,帮助那些希望了解vLLM多头查询注意力内核的人。阅读完本文档后,用户可能会更容易理解并跟进实际实现。
请注意,本文档可能未涵盖所有细节,例如如何计算对应数据的正确索引或点乘实现。但在阅读本文档并熟悉高层逻辑流程后,您应该能更容易阅读实际代码并理解具体细节。
输入¶
内核函数接收当前线程执行分配工作所需的参数列表。其中最重要的三个参数是输入指针q、k_cache和v_cache,它们指向需要读取和处理的全局内存中的查询(query)、键(key)和值(value)数据。输出指针out指向结果应写入的全局内存。这四个指针实际上引用的是多维数组,但每个线程仅访问分配给它的数据部分。为简化说明,此处省略了所有其他运行时参数。
template<typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, int PARTITION_SIZE = 0>
__device__ void paged_attention_kernel(
... // Other side args.
const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
... // Other side args.
)
函数签名上方还有一系列在编译时确定的模板参数。scalar_t表示查询、键和值数据元素的数据类型,例如FP16。HEAD_SIZE表示每个头中的元素数量。BLOCK_SIZE指每个块中的令牌数量。NUM_THREADS表示每个线程块中的线程数。PARTITION_SIZE代表张量并行GPU的数量(为简化起见,我们假设该值为0且张量并行已禁用)。
通过这些参数,我们需要执行一系列准备工作。这包括计算当前头索引、块索引以及其他必要的变量。不过目前我们可以暂时忽略这些准备工作,直接进入实际计算阶段。一旦我们掌握了整个流程,理解这些内容会变得更加容易。
概念¶
在我们深入计算流程之前,我想先介绍几个后续章节需要用到的概念。不过,如果您遇到任何难以理解的术语,也可以先跳过本节稍后再回看。
- 序列: 一个序列代表一个客户端请求。例如,
q指向的数据形状为[num_seqs, num_heads, head_size]。这表示q总共指向了num_seqs个查询序列数据。由于这个内核是单查询注意力内核,每个序列只有一个查询token。因此,num_seqs等于批次中处理的所有token总数。 - 上下文: 上下文由序列中已生成的标记组成。例如,
["What", "is", "your"]是上下文标记,而输入查询标记是"name"。模型可能会生成标记"?"。 - Vec: Vec是一个元素列表,这些元素会被一起获取和计算。对于查询(query)和键(key)数据,vec大小(
VEC_SIZE)的设定确保每个线程组可以一次性获取并计算16字节的数据。对于值(value)数据,vec大小(V_VEC_SIZE)的设定确保每个线程可以一次性获取并计算16字节的数据。例如,如果scalar_t是FP16(2字节)且THREAD_GROUP_SIZE为2,那么VEC_SIZE将为4,而V_VEC_SIZE将为8。 - 线程组: 线程组是由少量线程(
THREAD_GROUP_SIZE)组成的小组,每次获取并计算一个查询token和一个键token。每个线程仅处理token数据的一部分。一个线程组处理的总元素数称为x。例如,如果线程组包含2个线程且头大小为8,那么线程0处理索引为0、2、4、6的查询和键元素,而线程1处理索引为1、3、5、7的元素。 - 块(Block): vLLM中的键值缓存数据被分割成多个块。每个块存储固定数量(
BLOCK_SIZE)的令牌在一个头中的数据。每个块可能只包含整个上下文令牌的一部分。例如,如果块大小为16且头大小为128,那么对于一个头来说,一个块可以存储16 * 128 = 2048个元素。 - Warp: 一个warp是由32个线程(
WARP_SIZE)组成的组,它们在流式多处理器(SM)上同时执行。在这个内核中,每个warp一次处理一个查询token与一个完整块中所有key token之间的计算(它可能在多次迭代中处理多个块)。例如,如果有4个warps和6个块用于一个上下文,分配方式将是:warp 0处理第0、4块,warp 1处理第1、5块,warp 2处理第2块,warp 3处理第3块。 - 线程块: 线程块是一组可以访问相同共享内存的线程(
NUM_THREADS)。每个线程块包含多个线程束(NUM_WARPS),在这个内核中,每个线程块处理一个查询token与整个上下文中所有key token之间的计算。 - Grid: 网格是线程块的集合,定义了该集合的形状。在此内核中,形状为
(num_heads, num_seqs, max_num_partitions)。因此,每个线程块仅处理一个注意力头、一个序列和一个分区的计算。
查询¶
本节将介绍查询数据如何在内存中存储以及各线程如何获取。如前所述,每个线程组获取一个查询令牌数据,而每个线程本身仅处理查询令牌数据的一部分。在每个warp内,所有线程组将获取相同的查询令牌数据,但会与不同的键令牌数据进行乘法运算。
每个线程定义自己的q_ptr,它指向全局内存中分配的查询令牌数据。例如,如果VEC_SIZE为4且HEAD_SIZE为128,则q_ptr指向包含总共128个元素的数据,这些元素被分成128 / 4 = 32个向量。
接下来,我们需要将q_ptr指向的全局内存数据读取到共享内存中,存储为q_vecs。需要注意的是,每个向量被分配到不同的行。例如,如果THREAD_GROUP_SIZE为2,线程0将处理第0行向量,而线程1处理第1行向量。通过这种方式读取查询数据,相邻线程(如线程0和线程1)可以读取相邻内存,实现内存合并以提高性能。
关键¶
与"查询"部分类似,本节介绍了键的内存布局和分配方式。虽然每个线程组在内核运行时仅处理一个查询令牌,但它可能通过多次迭代处理多个键令牌。同时,每个warp将在多次迭代中处理多个键令牌块,确保所有上下文令牌在内核运行后被整个线程组处理完毕。此处的"处理"指的是执行查询数据与键数据之间的点积运算。
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride
+ physical_block_offset * x;
与q_ptr不同,每个线程中的k_ptr在不同迭代时会指向不同的键令牌。如上所示,k_ptr根据k_cache在分配的块、分配的头和分配的令牌处指向键令牌数据。
上图展示了键值数据的内存布局。假设BLOCK_SIZE为16,HEAD_SIZE为128,x为8,THREAD_GROUP_SIZE为2,总共有4个warp。每个矩形代表一个注意力头中一个键值令牌的所有元素,这些元素将由一个线程组处理。左半部分显示warp 0的16个键值令牌数据块,右半部分表示其他warp或迭代中剩余的键值令牌数据。每个矩形内部共有32个向量(一个令牌的128个元素),将由2个线程(一个线程组)分别处理。
接下来,我们需要从k_ptr读取关键令牌数据,并将其作为k_vecs存储在寄存器内存中。我们使用寄存器内存来存储k_vecs,因为它只会被一个线程访问一次,而q_vecs会被多个线程多次访问。每个k_vecs将包含多个向量以供后续计算。每个向量将在每次内部迭代时设置。向量的分配使得一个warp中的相邻线程可以一起读取相邻内存,这再次促进了内存合并。例如,线程0将读取向量0,而线程1将读取向量1。在下一个内部循环中,线程0将读取向量2,线程1将读取向量3,依此类推。
你可能对整个流程还感到有些困惑。别担心,请继续阅读下一节"QK"部分。它会以更清晰、更高层次的方式阐述查询(query)和键(key)的计算流程。
QK¶
如下面的伪代码所示,在整个for循环块之前,我们获取一个token的查询数据并将其存储在q_vecs中。然后,在外层for循环中,我们遍历指向不同token的各个k_ptrs,并在内层for循环中准备k_vecs。最后,我们对q_vecs和每个k_vecs执行点乘运算。
q_vecs = ...
for ... {
k_ptr = ...
for ... {
k_vecs[i] = ...
}
...
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
}
如前所述,每个线程每次仅获取部分查询和键令牌数据。然而,在Qk_dot<>::dot中会发生跨线程组归约操作。因此这里返回的qk不仅是部分查询与键令牌的点乘结果,实际上是完整查询与键令牌数据的整体计算结果。
例如,如果HEAD_SIZE的值为128且THREAD_GROUP_SIZE为2,那么每个线程的k_vecs将总共包含64个元素。然而,返回的qk实际上是128个查询元素与128个键元素点乘的结果。如果您想了解更多关于点乘和归约的细节,可以参考Qk_dot<>::dot的实现。不过为了简洁起见,本文档将不涉及这部分内容。
Softmax¶
接下来,我们需要计算所有qk的归一化softmax,如上所示,其中每个代表一个qk。为此,我们必须获取qk_max()的归约值和所有qk的exp_sum()。该归约操作应在整个线程块范围内执行,涵盖查询token与所有上下文key token之间的结果。
qk_max 和 logits¶
在我们获取qk结果后,可以立即用qk设置临时的logits结果(最终logits应存储归一化的softmax结果)。同时,我们还可以比较并收集当前线程组计算的所有qk的qk_max值。
if (thread_group_offset == 0) {
const bool mask = token_idx >= context_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
请注意,这里的logits位于共享内存中,因此每个线程组将为其分配的上下文标记设置字段。总体而言,logits的大小应为上下文标记的数量。
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
接下来我们需要获取每个warp内的归约qk_max。主要思路是让warp内的线程相互通信,最终获得最大的qk值。
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
最后,我们可以通过比较该线程块中所有warp的qk_max,得到整个线程块缩减后的qk_max。然后需要将最终结果广播到每个线程。
exp_sum¶
类似于qk_max,我们也需要从整个线程块中获取归约求和值。
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
...
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
首先,对每个线程组的所有exp值求和,同时将logits的每个条目从qk转换为exp(qk - qk_max)。请注意,这里的qk_max已经是整个线程块中的最大qk值。然后我们可以像处理qk_max一样,对整个线程块的exp_sum进行归约操作。
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
最后,通过缩减后的qk_max和exp_sum,我们可以得到最终的归一化softmax结果logits。这个logits变量将在后续步骤中用于与值数据进行点乘运算。此时,它应该存储了所有分配上下文token的qk归一化softmax结果。
值¶
现在我们需要获取value数据并与logits进行点乘运算。与query和key不同,value数据没有线程组的概念。如图所示,与key token的内存布局不同,同一列中的元素对应同一个value token。对于一块value数据,它有HEAD_SIZE行和BLOCK_SIZE列,这些数据被分割成多个v_vecs。
每个线程总是每次从相同的V_VEC_SIZE个token中获取V_VEC_SIZE个元素。因此,单个线程通过多次内部迭代从不同行但相同列中检索多个v_vec。对于每个v_vec,都需要与对应的logits_vec进行点乘,后者也是从logits中取出的V_VEC_SIZE个元素。总体而言,通过多次内部迭代,每个warp将处理一个值token块。而通过多次外部迭代,整个上下文的值token都将被处理完毕。
float accs[NUM_ROWS_PER_THREAD];
for ... { // Iteration over different blocks.
logits_vec = ...
for ... { // Iteration over different rows.
v_vec = ...
...
accs[i] += dot(logits_vec, v_vec);
}
}
如上伪代码所示,在外层循环中,与k_ptr类似,logits_vec会遍历不同的块并从logits中读取V_VEC_SIZE个元素。在内层循环中,每个线程会从相同的token中读取V_VEC_SIZE个元素作为v_vec并执行点积运算。需要注意的是,在每次内层迭代中,线程会为相同的token获取不同头位置的元素。点积结果随后会累加到accs中。因此,accs的每个条目都映射到当前线程分配的头位置。
例如,如果BLOCK_SIZE为16且V_VEC_SIZE为8,每个线程一次会为8个token获取8个值元素。每个元素来自相同注意力头位置的不同token。如果HEAD_SIZE为128且WARP_SIZE为32,在每次内部循环中,一个warp需要获取WARP_SIZE * V_VEC_SIZE = 256个元素。这意味着一个warp需要执行128 * 16 / 256 = 8次内部迭代来处理整个值token块。每个线程中的每个accs包含8个在不同注意力头位置累加的元素。对于线程0,accs变量将包含8个元素,这些是来自所有分配的8个token的某个值头的第0、32...224个元素的累加结果。
LV¶
现在,我们需要在每个warp内对accs执行归约操作。这个过程允许每个线程为块中所有token的指定头位置累加accs值。
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
接下来,我们对所有线程束中的accs执行归约操作,使每个线程都能获得所有上下文令牌指定头位置的accs累加值。请注意,每个线程中的accs仅存储所有上下文令牌整个头部部分元素的累加结果。但从整体来看,输出的所有结果都已完成计算,只是存储在不同的线程寄存器内存中。
Code
float* out_smem = reinterpret_cast<float*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) {
// Upper warps write to shared memory.
...
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
...
dst[row_idx] = accs[i];
}
// Lower warps update the output.
const float* src = &out_smem[warp_idx * HEAD_SIZE];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
...
accs[i] += src[row_idx];
}
// Write out the accs.
}
输出¶
现在我们可以将所有计算结果从本地寄存器内存写入最终的输出全局内存。
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ head_idx * max_num_partitions * HEAD_SIZE
+ partition_idx * HEAD_SIZE;
首先,我们需要定义out_ptr变量,它指向分配序列和分配头部的起始地址。
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
}
}
最后,我们需要遍历不同的分配头位置,并根据out_ptr写出相应的累积结果。
引用¶
@inproceedings{kwon2023efficient,
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
year={2023}
}