speechbrain.nnet.attention 模块
实现注意力模块的库。
- Authors
周珏洁 2020
钟建元 2020
洛伦·卢戈斯奇 2020
萨穆埃莱·科内尔 2020
摘要
类:
该类实现了用于序列到序列学习的内容注意力模块。 |
|
该类实现了一个用于序列到序列学习的单头键值注意力模块。 |
|
该类实现了用于序列到序列学习的位置感知注意力模块。 |
|
该类是torch.nn.MultiHeadAttention的多头注意力机制的封装。 |
|
该类实现了“Attention Is All You Need”中的位置前馈模块。 |
|
用于 |
|
该类实现了类似于Transformer XL中的相对多头实现 https://arxiv.org/pdf/1901.02860.pdf |
参考
- class speechbrain.nnet.attention.ContentBasedAttention(enc_dim, dec_dim, attn_dim, output_dim, scaling=1.0)[source]
基础:
Module该类实现了用于序列到序列学习的内容注意力模块。
参考:通过联合学习对齐和翻译的神经机器翻译,Bahdanau等人。https://arxiv.org/pdf/1409.0473.pdf
- Parameters:
Example
>>> enc_tensor = torch.rand([4, 10, 20]) >>> enc_len = torch.ones([4]) * 10 >>> dec_tensor = torch.rand([4, 25]) >>> net = ContentBasedAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5) >>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor) >>> out_tensor.shape torch.Size([4, 5])
- class speechbrain.nnet.attention.LocationAwareAttention(enc_dim, dec_dim, attn_dim, output_dim, conv_channels, kernel_size, scaling=1.0)[source]
基础:
Module该类实现了用于序列到序列学习的位置感知注意力模块。
参考:基于注意力的语音识别模型,Chorowski等人。 https://arxiv.org/pdf/1506.07503.pdf
- Parameters:
Example
>>> enc_tensor = torch.rand([4, 10, 20]) >>> enc_len = torch.ones([4]) * 10 >>> dec_tensor = torch.rand([4, 25]) >>> net = LocationAwareAttention( ... enc_dim=20, ... dec_dim=25, ... attn_dim=30, ... output_dim=5, ... conv_channels=10, ... kernel_size=100) >>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor) >>> out_tensor.shape torch.Size([4, 5])
- class speechbrain.nnet.attention.KeyValueAttention(enc_dim, dec_dim, attn_dim, output_dim)[source]
基础:
Module该类实现了一个用于序列到序列学习的单头键值注意力模块。
参考:Vaswani等人的《Attention Is All You Need》,第3.2.1节
- Parameters:
Example
>>> enc_tensor = torch.rand([4, 10, 20]) >>> enc_len = torch.ones([4]) * 10 >>> dec_tensor = torch.rand([4, 25]) >>> net = KeyValueAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5) >>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor) >>> out_tensor.shape torch.Size([4, 5])
- class speechbrain.nnet.attention.RelPosEncXL(emb_dim: int, dtype: dtype = torch.float32)[source]
基础:
Module用于
RelPosMHAXL的相对位置编码。- Parameters:
emb_dim (int) – 嵌入的大小,它也控制位置嵌入的最后一个维度的大小
dtype (torch.dtype, optional) – 如果未指定,默认为
torch.float32。控制输出嵌入的数据类型(但不影响计算的精度,计算精度仍为torch.float32)。
- class speechbrain.nnet.attention.RelPosMHAXL(embed_dim, num_heads, dropout=0.0, vbias=False, vdim=None, mask_pos_future=False)[source]
基础:
Module该类实现了类似于Transformer XL中的相对多头实现 https://arxiv.org/pdf/1901.02860.pdf
- Parameters:
Example
>>> inputs = torch.rand([6, 60, 512]) >>> pos_emb = torch.rand([1, 2*60-1, 512]) >>> net = RelPosMHAXL(num_heads=8, embed_dim=inputs.shape[-1]) >>> outputs, attn = net(inputs, inputs, inputs, pos_emb) >>> outputs.shape torch.Size([6, 60, 512])
- forward(query, key, value, pos_embs, key_padding_mask=None, attn_mask=None, return_attn_weights=True)[source]
计算注意力。
- Parameters:
query (torch.Tensor) – (B, L, E) 其中 L 是目标序列长度, B 是批次大小,E 是嵌入维度。
key (torch.Tensor) – (B, S, E) 其中 S 是源序列长度, B 是批次大小,E 是嵌入维度。
value (torch.Tensor) – (B, S, E) 其中 S 是源序列长度, B 是批次大小,E 是嵌入维度。
pos_embs (torch.Tensor) – 双向正弦位置嵌入张量 (1, 2*S-1, E),其中 S 是源序列和目标序列长度之间的最大长度,E 是嵌入维度。
key_padding_mask (torch.Tensor) – (B, S) 其中 B 是批量大小,S 是源序列长度。如果提供了 ByteTensor,非零位置将被忽略,而零位置将保持不变。如果提供了 BoolTensor,值为 True 的位置将被忽略,而值为 False 的位置将保持不变。
attn_mask (torch.Tensor) – 2D 掩码 (L, S),其中 L 是目标序列长度,S 是源序列长度。 3D 掩码 (N*num_heads, L, S),其中 N 是批量大小,L 是目标序列长度,S 是源序列长度。attn_mask 确保位置 i 可以关注未掩码的位置。如果提供了 ByteTensor,非零位置不允许关注,而零位置将保持不变。如果提供了 BoolTensor,True 的位置不允许关注,而 False 的值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。
return_attn_weights (bool) – 是否额外返回注意力权重。
- Returns:
out (torch.Tensor) – (B, L, E) 其中 L 是目标序列长度,B 是批量大小,E 是嵌入维度。
attn_score (torch.Tensor) – (B, L, S) 其中 B 是批量大小,L 是目标序列长度,S 是源序列长度。
- class speechbrain.nnet.attention.MultiheadAttention(nhead, d_model, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)[source]
基础:
Module该类是torch.nn.MultiHeadAttention的多头注意力机制的封装。
参考:https://pytorch.org/docs/stable/nn.html
- Parameters:
Example
>>> inputs = torch.rand([8, 60, 512]) >>> net = MultiheadAttention(nhead=8, d_model=inputs.shape[-1]) >>> outputs, attn = net(inputs, inputs, inputs) >>> outputs.shape torch.Size([8, 60, 512])
- forward(query, key, value, attn_mask: Tensor | None = None, key_padding_mask: Tensor | None = None, return_attn_weights: bool = True, pos_embs: Tensor | None = None)[source]
计算注意力。
- Parameters:
query (torch.Tensor) – (B, L, E) 其中 L 是目标序列长度, B 是批次大小,E 是嵌入维度。
key (torch.Tensor) – (B, S, E) 其中 S 是源序列长度, B 是批次大小,E 是嵌入维度。
value (torch.Tensor) – (B, S, E) 其中 S 是源序列长度, B 是批次大小,E 是嵌入维度。
attn_mask (torch.Tensor, optional) – 2D 掩码 (L, S),其中 L 是目标序列长度,S 是源序列长度。 3D 掩码 (N*num_heads, L, S),其中 N 是批次大小,L 是目标序列长度,S 是源序列长度。attn_mask 确保位置 i 可以关注未掩码的位置。如果提供了 ByteTensor,非零位置不允许关注,而零位置将保持不变。如果提供了 BoolTensor,True 的位置不允许关注,而 False 的值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。
key_padding_mask (torch.Tensor, optional) – (B, S) 其中 B 是批量大小,S 是源序列长度。如果提供了 ByteTensor,非零位置将被忽略,而零位置将保持不变。如果提供了 BoolTensor,值为 True 的位置将被忽略,而值为 False 的位置将保持不变。
return_attn_weights (bool, 可选) – 如果为True,则额外返回注意力权重,否则不返回。
pos_embs (torch.Tensor, optional) – 添加到注意力图的位置嵌入,形状为 (L, S, E) 或 (L, S, 1)。
- Returns:
attn_output (torch.Tensor) – (B, L, E) 其中 L 是目标序列长度,B 是批次大小,E 是嵌入维度。
attn_output_weights (torch.Tensor) – (B, L, S) 其中 B 是批次大小,L 是目标序列长度,S 是源序列长度。 只有在
return_attn_weights=True(默认为 True)时才会返回。
- class speechbrain.nnet.attention.PositionalwiseFeedForward(d_ffn, input_shape=None, input_size=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>)[source]
基础:
Module该类实现了“Attention Is All You Need”中的位置前馈模块。
- Parameters:
Example
>>> inputs = torch.rand([8, 60, 512]) >>> net = PositionalwiseFeedForward(256, input_size=inputs.shape[-1]) >>> outputs = net(inputs) >>> outputs.shape torch.Size([8, 60, 512])