speechbrain.nnet.loss.guidedattn_loss 模块
引导注意力损失实现
这种损失可以用于加速模型的训练,其中输入和输出之间的对应关系大致是线性的,并且注意力对齐预计大致是对角线的,例如从字形到音素和文本到语音。
作者 * Artem Ploujnikov 2021
摘要
类:
一种损失实现,强制注意力矩阵接近对角线,对远离对角线的区域施加逐渐增大的惩罚。 |
参考
- class speechbrain.nnet.loss.guidedattn_loss.GuidedAttentionLoss(sigma=0.2)[source]
基础:
Module一种损失实现,强制注意力矩阵接近对角线,对远离对角线的区域施加逐渐增大的惩罚。这对于序列到序列模型非常有用,其中输出序列预期与输入序列紧密对应,例如TTS或G2P。
https://arxiv.org/abs/1710.08969
该实现受到R9Y9 DeepVoice3模型的启发 https://github.com/r9y9/deepvoice3_pytorch
它应该大致相当于它;然而,它已经完全向量化了。
- Parameters:
sigma (float) – 引导注意力权重
Example
注意:在实际场景中,input_lengths 和 target_lengths 会来自数据批次,而 alignments 会来自模型 >>> import torch >>> from speechbrain.nnet.loss.guidedattn_loss import GuidedAttentionLoss >>> loss = GuidedAttentionLoss(sigma=0.2) >>> input_lengths = torch.tensor([2, 3]) >>> target_lengths = torch.tensor([3, 4]) >>> alignments = torch.tensor( … [ … [ … [0.8, 0.2, 0.0], … [0.4, 0.6, 0.0], … [0.2, 0.8, 0.0], … [0.0, 0.0, 0.0], … ], … [ … [0.6, 0.2, 0.2], … [0.1, 0.7, 0.2], … [0.3, 0.4, 0.3], … [0.2, 0.3, 0.5], … ], … ] … ) >>> loss(alignments, input_lengths, target_lengths) tensor(0.1142)
- forward(attention, input_lengths, target_lengths, max_input_len=None, max_target_len=None)[source]
计算单个批次的引导注意力损失
- Parameters:
注意 (torch.Tensor) – 一个填充的注意力/对齐矩阵 (batch, targets, inputs)
input_lengths (torch.tensor) – 一个 (batch, lengths) 的输入长度张量
target_lengths (torch.tensor) – 一个 (batch, lengths) 的目标长度张量
max_input_len (int) – 最大输入长度 - 可选,如果未计算,将设置为target_lengths的最大值。在使用数据并行时,可能需要显式设置它。
max_target_len (int) – 最大目标长度 - 可选,如果未计算,将设置为target_lengths的最大值。在使用数据并行时,可能需要显式设置它。
- Returns:
loss – 一个包含损失值的单元素张量
- Return type:
torch.Tensor