多头注意力机制¶
- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[源代码]¶
允许模型联合关注来自不同表示子空间的信息。
论文中描述的方法: Attention Is All You Need。
多头注意力机制定义为:
其中 。
nn.MultiHeadAttention将会尽可能使用优化的scaled_dot_product_attention()实现。除了支持新的
scaled_dot_product_attention()函数外,为了加速推理,MHA 将使用 支持嵌套张量的快速路径推理,当且仅当:正在计算自注意力(即,
query、key和value是相同的张量)。输入是批量处理的(3D),且
batch_first==True要么自动求导被禁用(使用
torch.inference_mode或torch.no_grad),要么没有张量参数requires_grad训练被禁用(使用
.eval())add_bias_kv是Falseadd_zero_attn是Falsekdim和vdim等于embed_dim如果传递了 NestedTensor,则不会传递
key_padding_mask也不会传递attn_maskautocast 已禁用
如果正在使用优化的推理快速路径实现,可以传递一个 NestedTensor 用于
query/key/value来更高效地表示填充,而不是使用填充掩码。在这种情况下,将返回一个 NestedTensor,并且可以预期获得与输入中填充部分比例成正比的额外加速。- Parameters
embed_dim – 模型的总维度。
num_heads – 并行注意力头的数量。注意
embed_dim将被分割到num_heads中(即每个头将具有维度embed_dim // num_heads)。dropout – 在
attn_output_weights上的 dropout 概率。默认值:0.0(无 dropout)。偏差 – 如果指定,将偏差添加到输入/输出投影层。默认值:
True。add_bias_kv – 如果指定,则在 dim=0 处向键和值序列添加偏差。默认值:
False。add_zero_attn – 如果指定,在dim=1处向键和值序列添加一个新的零批次。 默认值:
False。kdim – 键的总特征数。默认值:
None(使用kdim=embed_dim)。vdim – 值的总特征数。默认值:
None(使用vdim=embed_dim)。batch_first – 如果
True,则输入和输出张量以 (batch, seq, feature) 的形式提供。默认值:False(seq, batch, feature)。
示例:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[源代码]¶
使用查询、键和值嵌入计算注意力输出。
支持可选参数用于填充、掩码和注意力权重。
- Parameters
查询 (张量) – 查询嵌入的形状为 对于非批量输入, 当
batch_first=False或 当batch_first=True,其中 是目标序列长度, 是批量大小,并且 是查询嵌入维度embed_dim。 查询与键值对进行比较以生成输出。 有关更多详细信息,请参阅“Attention Is All You Need”。key (Tensor) – 形状为 的键嵌入,用于非批量输入, 当
batch_first=False或 当batch_first=True,其中 是源序列长度, 是批量大小,并且 是键嵌入维度kdim。 参见“Attention Is All You Need”了解更多详情。值 (张量) – 形状为 的值嵌入,用于未批处理的输入, 当
batch_first=False或 当batch_first=True,其中 是源序列长度, 是批量大小,并且 是值嵌入维度vdim。 参见“Attention Is All You Need”了解更多详情。key_padding_mask (可选[张量]) – 如果指定,形状为 的掩码,指示在注意力计算中忽略
key中的哪些元素(即视为“填充”)。对于未批处理的 query,形状应为 。 支持二进制和浮点掩码。 对于二进制掩码,True值表示相应的key值将在注意力计算中被忽略。对于浮点掩码,它将直接添加到相应的key值。need_weights (bool) – 如果指定,除了返回
attn_outputs外,还会返回attn_output_weights。 设置need_weights=False以使用优化的scaled_dot_product_attention并在MHA中获得最佳性能。 默认值:True。attn_mask (可选[张量]) – 如果指定,一个2D或3D的掩码,防止某些位置进行注意力计算。必须为形状 或 ,其中 是批量大小, 是目标序列长度,并且 是源序列长度。2D掩码将在批量中广播,而3D掩码允许为批量中的每个条目使用不同的掩码。 支持二进制和浮点掩码。对于二进制掩码,
True值表示不允许相应位置进行注意力计算。对于浮点掩码,掩码值将被添加到注意力权重中。 如果同时提供了 attn_mask 和 key_padding_mask,它们的类型应匹配。average_attn_weights (bool) – 如果为真,表示返回的
attn_weights应该在头之间进行平均。否则,attn_weights将分别提供给每个头。请注意,此标志仅在need_weights=True时生效。默认值:True(即在头之间平均权重)is_causal (bool) – 如果指定,则将因果掩码作为注意力掩码应用。 默认值:
False。 警告:is_causal提供了一个提示,即attn_mask是因果掩码。提供错误的提示可能导致 错误的执行,包括前向和后向兼容性。
- Return type
- Outputs:
attn_output - 注意力输出的形状为 当输入未批处理时, 当
batch_first=False或 当batch_first=True, 其中 是目标序列长度, 是批量大小, 是嵌入维度embed_dim。attn_output_weights - 仅在
need_weights=True时返回。如果average_attn_weights=True, 返回形状为 的注意力权重,当输入未批处理时, 或 ,其中 是批量大小, 是目标序列长度, 是源序列长度。如果average_attn_weights=False,返回每个头的注意力权重, 形状为 当输入未批处理时, 或 。
注意
batch_first 参数在未批处理的输入中被忽略。
- merge_masks(attn_mask, key_padding_mask, query)[源代码]¶
确定掩码类型并在必要时合并掩码。
如果只提供了一个掩码,则返回该掩码及其对应的掩码类型。如果提供了两个掩码,它们都将被扩展为形状
(batch_size, num_heads, seq_len, seq_len),通过逻辑或组合,并返回掩码类型 2 :param attn_mask: 形状为(seq_len, seq_len)的注意力掩码,掩码类型 0 :param key_padding_mask: 形状为(batch_size, seq_len)的填充掩码,掩码类型 1 :param query: 形状为(batch_size, seq_len, embed_dim)的查询嵌入- Returns
合并掩码 mask_type: 合并掩码类型 (0, 1, 或 2)
- Return type
合并的掩码