speechbrain.nnet.loss.guidedattn_loss 模块

引导注意力损失实现

这种损失可以用于加速模型的训练,其中输入和输出之间的对应关系大致是线性的,并且注意力对齐预计大致是对角线的,例如从字形到音素和文本到语音。

作者 * Artem Ploujnikov 2021

摘要

类:

GuidedAttentionLoss

一种损失实现,强制注意力矩阵接近对角线,对远离对角线的区域施加逐渐增大的惩罚。

参考

class speechbrain.nnet.loss.guidedattn_loss.GuidedAttentionLoss(sigma=0.2)[source]

基础:Module

一种损失实现,强制注意力矩阵接近对角线,对远离对角线的区域施加逐渐增大的惩罚。这对于序列到序列模型非常有用,其中输出序列预期与输入序列紧密对应,例如TTS或G2P。

https://arxiv.org/abs/1710.08969

该实现受到R9Y9 DeepVoice3模型的启发 https://github.com/r9y9/deepvoice3_pytorch

它应该大致相当于它;然而,它已经完全向量化了。

Parameters:

sigma (float) – 引导注意力权重

Example

注意:在实际场景中,input_lengths 和 target_lengths 会来自数据批次,而 alignments 会来自模型 >>> import torch >>> from speechbrain.nnet.loss.guidedattn_loss import GuidedAttentionLoss >>> loss = GuidedAttentionLoss(sigma=0.2) >>> input_lengths = torch.tensor([2, 3]) >>> target_lengths = torch.tensor([3, 4]) >>> alignments = torch.tensor( … [ … [ … [0.8, 0.2, 0.0], … [0.4, 0.6, 0.0], … [0.2, 0.8, 0.0], … [0.0, 0.0, 0.0], … ], … [ … [0.6, 0.2, 0.2], … [0.1, 0.7, 0.2], … [0.3, 0.4, 0.3], … [0.2, 0.3, 0.5], … ], … ] … ) >>> loss(alignments, input_lengths, target_lengths) tensor(0.1142)

forward(attention, input_lengths, target_lengths, max_input_len=None, max_target_len=None)[source]

计算单个批次的引导注意力损失

Parameters:
  • 注意 (torch.Tensor) – 一个填充的注意力/对齐矩阵 (batch, targets, inputs)

  • input_lengths (torch.tensor) – 一个 (batch, lengths) 的输入长度张量

  • target_lengths (torch.tensor) – 一个 (batch, lengths) 的目标长度张量

  • max_input_len (int) – 最大输入长度 - 可选,如果未计算,将设置为target_lengths的最大值。在使用数据并行时,可能需要显式设置它。

  • max_target_len (int) – 最大目标长度 - 可选,如果未计算,将设置为target_lengths的最大值。在使用数据并行时,可能需要显式设置它。

Returns:

loss – 一个包含损失值的单元素张量

Return type:

torch.Tensor

guided_attentions(input_lengths, target_lengths, max_input_len=None, max_target_len=None)[source]

计算引导注意力矩阵

Parameters:
  • input_lengths (torch.Tensor) – 输入长度的张量

  • target_lengths (torch.Tensor) – 目标长度的张量

  • max_input_len (int) – 最大输入长度 - 可选,如果未计算,将设置为target_lengths的最大值。在使用数据并行时,可能需要显式设置它。

  • max_target_len (int) – 最大目标长度 - 可选,如果未计算,将设置为target_lengths的最大值。在使用数据并行时,可能需要显式设置它。

Returns:

soft_mask – 引导注意力张量,形状为 (batch, max_input_len, max_target_len)

Return type:

torch.Tensor