多头注意力机制¶
- class torchtune.modules.MultiHeadAttention(*, embed_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, q_proj: Module, k_proj: Module, v_proj: Module, output_proj: Module, pos_embeddings: Optional[Module] = None, q_norm: Optional[Module] = None, k_norm: Optional[Module] = None, kv_cache: Optional[KVCache] = None, max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0)[source]¶
支持分组查询注意力(GQA)的多头注意力层,引入于https://arxiv.org/abs/2305.13245v1。
GQA是多头注意力(MHA)的一个版本,它通过为每个键和值头分组n个查询头,使用比查询头更少的键/值头。多查询注意力是一个极端版本,其中我们有一个由所有查询头共享的单个键和值头。
以下是使用 num_heads = 4 的 MHA、GQA 和 MQA 的示例
(文档的功劳: litgpt.Config).
┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ │ │ │ │ │ │ │ ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ MHA GQA MQA n_kv_heads =4 n_kv_heads=2 n_kv_heads=1- Parameters:
embed_dim (int) – 模型的嵌入维度
num_heads (int) – 查询头的数量。对于MHA来说,这也是键和值的头的数量
num_kv_heads (int) – 键和值头的数量。用户应确保
num_heads % num_kv_heads == 0。对于标准的MHA,设置num_kv_heads == num_heads, 对于GQA,设置num_kv_heads < num_heads,对于MQA,设置num_kv_heads == 1。head_dim (int) – 每个头的维度,通过
embed_dim // num_heads计算得出。q_proj (nn.Module) – 查询的投影层。
k_proj (nn.Module) – 用于键的投影层。
v_proj (nn.Module) – 值的投影层。
output_proj (nn.Module) – 输出投影层。
pos_embeddings (可选[nn.Module]) – 位置嵌入层,例如 RotaryPositionalEmbeddings。
q_norm (可选[nn.Module]) – 查询的归一化层,例如 RMSNorm。在解码时,这会在从 kv_cache 更新之前应用。这意味着它只支持令牌级别的归一化,而不支持批次或序列级别的归一化。
k_norm (可选[nn.Module]) – 键的归一化层,如果设置了q_norm,则必须设置。
kv_cache (可选[KVCache]) – 用于缓存键和值的KVCache对象
max_seq_len (int) – 模型支持的最大序列长度。 这是计算RoPE缓存所需的。默认值:4096。
is_causal (bool) – 当未提供掩码时,将默认掩码设置为因果掩码
attn_dropout (float) – 传递给scaled_dot_product_attention函数的dropout值。 默认值为0.0。
- Raises:
ValueError – 如果
num_heads % num_kv_heads != 0ValueError – 如果
embed_dim % num_heads != 0ValueError – 如果
attn_dropout < 0或attn_dropout > 1ValueError – 如果定义了q_norm而没有定义k_norm,反之亦然
- forward(x: Tensor, y: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) Tensor[source]¶
- Parameters:
x (torch.Tensor) – 输入张量,形状为 [b x s_x x d],用于查询
y (可选[torch.Tensor]) – 第二个输入张量,形状为 [b x s_y x d],是 k 和 v 的输入。对于自注意力机制,x=y。仅在启用 kv_cache 时为可选。
mask (Optional[_MaskType]) –
Used to mask the scores after the query-key multiplication and before the softmax. Either:
A boolean tensor with shape
[b x s x s],[b x s x self.encoder_max_cache_seq_len], or[b x s x self.decoder_max_cache_seq_len]if using KV-cacheing with encoder/decoder layers. A value of True in rowiand columnjmeans tokeniattends to tokenj. A value of False means tokenidoes not attend to tokenj. If no mask is specified, a causal mask is used by default.A
BlockMaskfor document masking in a packed sequence created via create_block_mask. We useflex_attention()when computing attention with block masks. Default is None.input_pos (可选[torch.Tensor]) – 可选的张量,包含每个标记的位置ID。在训练期间,这用于指示每个标记相对于其样本的位置,形状为 [b x s]。在推理期间,这表示当前标记的位置。如果未提供,则假定标记的索引为其位置ID。默认值为None。
- Raises:
ValueError – 如果没有
y输入且kv_cache未启用。- Returns:
应用注意力机制后的输出张量
- Return type:
- Notation used for tensor shapes:
b: 批量大小
s_x: x的序列长度
s_y: y的序列长度
n_h: 头数
n_kv: 键值头数量
d: 嵌入维度
h_d: 头部维度