speechbrain.lobes.models.transformer.TransformerST 模块
用于ST的Transformer,采用SpeechBrain风格。
作者 * YAO FEI, CHENG 2021
摘要
类:
这是ST的transformer模型的实现。 |
参考
- class speechbrain.lobes.models.transformer.TransformerST.TransformerST(tgt_vocab, input_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, positional_encoding='fixed_abs_sine', normalize_before=False, kernel_size: int | None = 31, bias: bool | None = True, encoder_module: str | None = 'transformer', conformer_activation: ~torch.nn.modules.module.Module | None = <class 'speechbrain.nnet.activations.Swish'>, attention_type: str | None = 'regularMHA', max_length: int | None = 2500, causal: bool | None = True, ctc_weight: float = 0.0, asr_weight: float = 0.0, mt_weight: float = 0.0, asr_tgt_vocab: int = 0, mt_src_vocab: int = 0)[source]
基础类:
TransformerASR这是ST的transformer模型的实现。
该架构基于论文“Attention Is All You Need”: https://arxiv.org/pdf/1706.03762.pdf
- Parameters:
tgt_vocab (int) – 词汇表的大小。
input_size (int) – 输入特征大小。
d_model (int, optional) – 嵌入维度大小。 (默认=512).
nhead (int, 可选) – 多头注意力模型中的头数(默认=8)。
num_encoder_layers (int, 可选) – 编码器中子编码层的数量(默认=6)。
num_decoder_layers (int, 可选) – 解码器中子解码层的数量(默认=6)。
d_ffn (int, 可选) – 前馈网络模型的维度(默认=2048)。
dropout (int, 可选) – 丢弃率的值(默认=0.1)。
activation (torch.nn.Module, optional) – FFN层的激活函数。 推荐使用:relu 或 gelu(默认=relu)。
positional_encoding (str, 可选) – 使用的位置编码类型。例如,‘fixed_abs_sine’ 表示固定的绝对位置编码。
normalize_before (bool, optional) – 是否在Transformer层中的MHA或FFN之前或之后应用归一化。 默认为True,因为这被证明可以带来更好的性能和训练稳定性。
kernel_size (int, optional) – 当使用Conformer时,卷积层中的核大小。
bias (bool, 可选) – 是否在Conformer卷积层中使用偏置。
encoder_module (str, 可选) – 在编码器中选择Conformer或Transformer。解码器固定为Transformer。
conformer_activation (torch.nn.Module, optional) – 在Conformer卷积层之后使用的激活模块。例如Swish、ReLU等。它必须是一个torch模块。
attention_type (str, 可选) – 在所有Transformer或Conformer层中使用的注意力层类型。 例如 regularMHA 或 RelPosMHA。
max_length (int, 可选) – 输入中目标和源序列的最大长度。 用于位置编码。
causal (bool, 可选) – 编码器是否应该是因果的(解码器总是因果的)。 如果是因果的,Conformer卷积层是因果的。
ctc_weight (float) – 用于asr任务的ctc权重
asr_weight (float) – 用于计算损失的asr任务的权重
mt_weight (float) – 用于计算损失的mt任务的权重
asr_tgt_vocab (int) – ASR目标语言的大小
mt_src_vocab (int) – mt源语言的大小
Example
>>> src = torch.rand([8, 120, 512]) >>> tgt = torch.randint(0, 720, [8, 120]) >>> net = TransformerST( ... 720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU, ... ctc_weight=1, asr_weight=0.3, ... ) >>> enc_out, dec_out = net.forward(src, tgt) >>> enc_out.shape torch.Size([8, 120, 512]) >>> dec_out.shape torch.Size([8, 120, 512])
- forward_asr(encoder_out, src, tgt, wav_len, pad_idx=0)[source]
该方法实现了asr任务的解码步骤
- Parameters:
encoder_out (torch.Tensor) – 编码器的表示(必需)。
src (torch.Tensor) – 输入序列(必需)。
tgt (torch.Tensor) – 解码器的序列(转录)(必需)。
wav_len (torch.Tensor) – 输入张量的长度(必需)。
pad_idx (int) –
标记的索引(默认=0)。
- Returns:
asr_decoder_out – ASR解码器的一步。
- Return type:
torch.Tensor
- forward_mt(src, tgt, pad_idx=0)[source]
此方法实现了mt任务的前向步骤
- Parameters:
src (torch.Tensor) – 编码器的输入序列(转录)(必需)。
tgt (torch.Tensor) – 解码器的输入序列(翻译)(必需)。
pad_idx (int) –
标记的索引(默认=0)。
- Returns:
encoder_out (torch.Tensor) – 编码器的输出
decoder_out (torch.Tensor) – 解码器的输出
- forward_mt_decoder_only(src, tgt, pad_idx=0)[source]
此方法使用wav2vec编码器为mt任务实现前向步骤 (与上述相同,但没有编码器堆栈)
- Parameters:
(转录) (src) – 来自w2v2编码器的输出特征
(翻译) (tgt) – 解码器的序列(必需)。
pad_idx (int) – <pad> 标记的索引(默认=0)。
- decode_asr(tgt, encoder_out)[source]
该方法实现了变压器模型的解码步骤。
- Parameters:
tgt (torch.Tensor) – 解码器的输入序列。
encoder_out (torch.Tensor) – 编码器的隐藏输出。
- Returns:
prediction (torch.Tensor) – 预测的输出。
multihead_attns (torch.Tensor) – 最后一步的注意力。
- make_masks_for_mt(src, tgt, pad_idx=0)[source]
此方法生成用于训练变压器模型的掩码。
- Parameters:
src (torch.Tensor) – 编码器的输入序列(必需)。
tgt (torch.Tensor) – 解码器的输入序列(必需)。
pad_idx (int) – <pad> 标记的索引(默认=0)。
- Returns:
src_key_padding_mask (torch.Tensor) – 由于填充而需要屏蔽的时间步
tgt_key_padding_mask (torch.Tensor) – 由于填充而需要屏蔽的时间步
src_mask (torch.Tensor) – 由于因果关系而需要屏蔽的时间步
tgt_mask (torch.Tensor) – 由于因果关系而需要屏蔽的时间步