Shortcuts

Transformer自注意力层

class torchtune.modules.TransformerSelfAttentionLayer(attn: MultiHeadAttention, mlp: Module, *, sa_norm: Optional[Module] = None, mlp_norm: Optional[Module] = None, sa_scale: Optional[Module] = None, mlp_scale: Optional[Module] = None)[source]

Transformer层源自Llama2模型。在注意力FF层之前应用归一化。

Parameters:
  • attn (MultiHeadAttention) – 注意力模块。

  • mlp (nn.Module) – 前馈模块。

  • sa_norm (可选[nn.Module]) – 在自注意力之前应用的归一化。

  • mlp_norm (可选[nn.Module]) – 在应用前馈层之前要应用的归一化。

  • sa_scale (可选[nn.Module]) – 用于缩放自注意力输出的模块。

  • mlp_scale (可选[nn.Module]) – 用于缩放前馈输出的模块。

caches_are_enabled() bool[source]

检查self.attn上的键值缓存是否启用。 参见 :func:~torchtune.modules.TransformerDecoder.caches_are_enabled`。

caches_are_setup() bool[source]

检查键值缓存是否在self.attn上设置。 参见 :func:~torchtune.modules.TransformerDecoder.caches_are_setup`。

forward(x: Tensor, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None, **kwargs: Dict) Tensor[source]
Parameters:
  • x (torch.Tensor) – 输入张量,形状为 [batch_size x seq_length x embed_dim]

  • mask (Optional[_MaskType]) –

    Used to mask the scores after the query-key multiplication and before the softmax. Either:

    A boolean tensor with shape [b x s x s], [b x s x self.encoder_max_cache_seq_len], or [b x s x self.encoder_max_cache_seq_len] if using KV-cacheing with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.

    A BlockMask for document masking in a packed sequence created via create_block_mask. We use flex_attention() when computing attention with block masks. Default is None.

  • input_pos (可选[torch.Tensor]) – 可选的张量,包含每个标记的位置ID。在训练期间,这用于指示每个标记相对于其样本的位置,形状为 [b x s]。在推理期间,这表示当前标记的位置。如果未提供,则假定标记的索引为其位置ID。默认值为None。

  • **kwargs (Dict) – 与自注意力无关的transformer层输入。

Returns:

输出张量与输入形状相同

[batch_size x seq_length x embed_dim]

Return type:

torch.Tensor

reset_cache()[source]

重置键值缓存。

setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: int, decoder_max_seq_len: int) None[source]

为注意力计算设置键值缓存。

Parameters:
  • batch_size (int) – 缓存的批量大小。

  • dtype (torch.dpython:type) – 缓存的dtype。

  • encoder_max_seq_len (int) – 此参数在此层中被忽略。

  • decoder_max_seq_len (int) – 最大缓存序列长度。