Transformer交叉注意力层¶
- class torchtune.modules.TransformerCrossAttentionLayer(attn: MultiHeadAttention, mlp: Module, *, ca_norm: Optional[Module] = None, mlp_norm: Optional[Module] = None, ca_scale: Optional[Module] = None, mlp_scale: Optional[Module] = None)[source]¶
交叉注意力Transformer层遵循与TransformerSelfAttentionLayer相同的惯例。 在注意力和FF层之前应用归一化。
- Parameters:
attn (MultiHeadAttention) – 注意力模块。
mlp (nn.Module) – 前馈模块。
ca_norm (可选[nn.Module]) – 在交叉注意力之前应用的归一化。
mlp_norm (可选[nn.Module]) – 在应用前馈层之前要应用的归一化。
ca_scale (可选[nn.Module]) – 用于缩放交叉注意力输出的模块。
mlp_scale (可选[nn.Module]) – 用于缩放前馈输出的模块。
- Raises:
AssertionError – 如果设置了attn.pos_embeddings。
- caches_are_enabled() bool[source]¶
检查
self.attn上的键值缓存是否启用。 参见 :func:~torchtune.modules.TransformerDecoder.caches_are_enabled`。
- caches_are_setup() bool[source]¶
检查键值缓存是否在
self.attn上设置。 参见 :func:~torchtune.modules.TransformerDecoder.caches_are_setup`。
- forward(x: Tensor, *, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, **kwargs: Dict) Tensor[source]¶
- Parameters:
x (torch.Tensor) – 输入张量,形状为 [batch_size x seq_length x embed_dim]
encoder_input (可选[torch.Tensor]) – 来自编码器的可选输入嵌入。形状为 [batch_size x token_sequence x embed_dim]
encoder_mask (Optional[torch.Tensor]) – 布尔张量,定义了令牌和编码器嵌入之间的关系矩阵。位置i,j处的True值表示令牌i可以关注解码器中的嵌入j。掩码的形状为[batch_size x token_sequence x embed_sequence]。默认值为None。
**kwargs (Dict) – 与自注意力无关的transformer层输入。
- Returns:
- 输出张量与输入形状相同
[batch_size x seq_length x embed_dim]
- Return type: