多头注意力机制¶
- class torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[源代码]¶
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[源代码]¶
- Note::
请参考
forward()获取更多信息
- Parameters
query (Tensor) – 将查询和一组键值对映射到输出。 参见“Attention Is All You Need”以获取更多详细信息。
key (Tensor) – 将查询和一组键值对映射到输出。 参见“Attention Is All You Need”了解更多详情。
值 (张量) – 将查询和一组键值对映射到输出。 参见“Attention Is All You Need”以获取更多详细信息。
key_padding_mask (可选[张量]) – 如果提供,指定键中的填充元素将被注意力机制忽略。当给定一个二进制掩码并且值为True时,注意力层上的相应值将被忽略。
need_weights (布尔值) – 输出 attn_output_weights。
attn_mask (可选[张量]) – 2D 或 3D 掩码,用于防止对某些位置进行注意力计算。2D 掩码将广播到所有批次,而 3D 掩码允许为每个批次的条目指定不同的掩码。
- Return type
- Shape:
输入:
查询: 其中 L 是目标序列长度,N 是批次大小,E 是嵌入维度。 如果
batch_first是True。键: , 其中 S 是源序列长度,N 是批次大小,E 是嵌入维度。 如果
batch_first是True。值: 其中 S 是源序列长度,N 是批次大小,E 是嵌入维度。 如果
batch_first是True。key_padding_mask: 其中 N 是批次大小,S 是源序列长度。 如果提供了一个 BoolTensor,值为
True的位置将被忽略,而值为False的位置将保持不变。attn_mask: 2D mask 其中 L 是目标序列长度,S 是源序列长度。 3D mask 其中 N 是批次大小,L 是目标序列长度, S 是源序列长度。attn_mask 确保位置 i 被允许关注未被掩码的位置。如果提供了 BoolTensor,位置为
True则不允许关注,而False值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。is_causal: 如果指定,应用因果掩码作为注意力掩码。与提供attn_mask互斥。 默认值:
False。average_attn_weights: 如果为真,表示返回的
attn_weights应该在头之间进行平均。否则,attn_weights将分别提供给每个头。请注意,此标志仅在need_weights=True.时有效。默认值:True(即在头之间平均权重)输出:
attn_output: 其中 L 是目标序列长度,N 是批次大小, E 是嵌入维度。 如果
batch_first是True。attn_output_weights: 如果
average_attn_weights=True,返回平均的注意力权重 跨头的形状 ,其中 N 是批量大小,L 是目标序列长度, S 是源序列长度。如果average_attn_weights=False,返回每个头的注意力权重 形状 。