Shortcuts

多头注意力机制

class torchtune.modules.MultiHeadAttention(*, embed_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, q_proj: Module, k_proj: Module, v_proj: Module, output_proj: Module, pos_embeddings: Optional[Module] = None, q_norm: Optional[Module] = None, k_norm: Optional[Module] = None, kv_cache: Optional[KVCache] = None, max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0)[source]

支持分组查询注意力(GQA)的多头注意力层,引入于https://arxiv.org/abs/2305.13245v1

GQA是多头注意力(MHA)的一个版本,它通过为每个键和值头分组n个查询头,使用比查询头更少的键/值头。多查询注意力是一个极端版本,其中我们有一个由所有查询头共享的单个键和值头。

以下是使用 num_heads = 4 的 MHA、GQA 和 MQA 的示例

(文档的功劳: litgpt.Config).

┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
│    │    │    │         │        │                 │
┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
│    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
│ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
└───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
        MHA                    GQA                   MQA
n_kv_heads =4          n_kv_heads=2           n_kv_heads=1
Parameters:
  • embed_dim (int) – 模型的嵌入维度

  • num_heads (int) – 查询头的数量。对于MHA来说,这也是键和值的头的数量

  • num_kv_heads (int) – 键和值头的数量。用户应确保 num_heads % num_kv_heads == 0。对于标准的MHA,设置num_kv_heads == num_heads, 对于GQA,设置num_kv_heads < num_heads,对于MQA,设置num_kv_heads == 1

  • head_dim (int) – 每个头的维度,通过 embed_dim // num_heads 计算得出。

  • q_proj (nn.Module) – 查询的投影层。

  • k_proj (nn.Module) – 用于键的投影层。

  • v_proj (nn.Module) – 值的投影层。

  • output_proj (nn.Module) – 输出投影层。

  • pos_embeddings (可选[nn.Module]) – 位置嵌入层,例如 RotaryPositionalEmbeddings。

  • q_norm (可选[nn.Module]) – 查询的归一化层,例如 RMSNorm。在解码时,这会在从 kv_cache 更新之前应用。这意味着它只支持令牌级别的归一化,而不支持批次或序列级别的归一化。

  • k_norm (可选[nn.Module]) – 键的归一化层,如果设置了q_norm,则必须设置。

  • kv_cache (可选[KVCache]) – 用于缓存键和值的KVCache对象

  • max_seq_len (int) – 模型支持的最大序列长度。 这是计算RoPE缓存所需的。默认值:4096。

  • is_causal (bool) – 当未提供掩码时,将默认掩码设置为因果掩码

  • attn_dropout (float) – 传递给scaled_dot_product_attention函数的dropout值。 默认值为0.0。

Raises:
  • ValueError – 如果 num_heads % num_kv_heads != 0

  • ValueError – 如果 embed_dim % num_heads != 0

  • ValueError – 如果 attn_dropout < 0attn_dropout > 1

  • ValueError – 如果定义了q_norm而没有定义k_norm,反之亦然

forward(x: Tensor, y: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) Tensor[source]
Parameters:
  • x (torch.Tensor) – 输入张量,形状为 [b x s_x x d],用于查询

  • y (可选[torch.Tensor]) – 第二个输入张量,形状为 [b x s_y x d],是 k 和 v 的输入。对于自注意力机制,x=y。仅在启用 kv_cache 时为可选。

  • mask (Optional[_MaskType]) –

    Used to mask the scores after the query-key multiplication and before the softmax. Either:

    A boolean tensor with shape [b x s x s], [b x s x self.encoder_max_cache_seq_len], or [b x s x self.decoder_max_cache_seq_len] if using KV-cacheing with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.

    A BlockMask for document masking in a packed sequence created via create_block_mask. We use flex_attention() when computing attention with block masks. Default is None.

  • input_pos (可选[torch.Tensor]) – 可选的张量,包含每个标记的位置ID。在训练期间,这用于指示每个标记相对于其样本的位置,形状为 [b x s]。在推理期间,这表示当前标记的位置。如果未提供,则假定标记的索引为其位置ID。默认值为None。

Raises:

ValueError – 如果没有 y 输入且 kv_cache 未启用。

Returns:

应用注意力机制后的输出张量

Return type:

torch.Tensor

Notation used for tensor shapes:
  • b: 批量大小

  • s_x: x的序列长度

  • s_y: y的序列长度

  • n_h: 头数

  • n_kv: 键值头数量

  • d: 嵌入维度

  • h_d: 头部维度

reset_cache()[source]

重置键值缓存。

setup_cache(batch_size: int, dtype: dtype, max_seq_len: int) None[source]

为注意力计算设置键值缓存。如果在kv_cache已经设置后调用,此操作将被跳过。

Parameters:
  • batch_size (int) – 缓存的批量大小。

  • dtype (torch.dpython:type) – 缓存的dtype。

  • max_seq_len (int) – 模型将运行的最大序列长度。