speechbrain.lobes.models.g2p.model 模块
用于字形到音素的注意力RNN模型
- Authors
Mirco Ravanelli 2021
阿尔乔姆·普洛日尼科夫 2021
摘要
类:
注意力RNN编码器-解码器模型 |
|
一个基于Transformer的字素到音素模型 |
|
一个小型编码器模块,用于降低维度并规范化词嵌入 |
函数:
创建一个虚拟的音素序列 |
|
计算输入维度(用于hparam文件) |
参考
- class speechbrain.lobes.models.g2p.model.AttentionSeq2Seq(enc, encoder_emb, emb, dec, lin, out, bos_token=0, use_word_emb=False, word_emb_enc=None)[source]
基础:
Module注意力RNN编码器-解码器模型
- Parameters:
- forward(grapheme_encoded, phn_encoded=None, word_emb=None)[source]
计算前向传播
- Parameters:
grapheme_encoded (torch.Tensor) – 以Torch张量编码的字素
phn_encoded (torch.Tensor) – 编码的音素
word_emb (torch.Tensor) – 词嵌入(可选)
- Returns:
p_seq (torch.Tensor) – 每个位置中token概率的(批次 x 位置 x token)张量
char_lens (torch.Tensor) – 字符序列长度的张量
encoder_out – 编码器的原始输出
- class speechbrain.lobes.models.g2p.model.WordEmbeddingEncoder(word_emb_dim, word_emb_enc_dim, norm=None, norm_type=None)[source]
基础:
Module一个小型编码器模块,用于降低维度并规范化词嵌入
- Parameters:
- forward(emb)[source]
计算嵌入的前向传播
- Parameters:
emb (torch.Tensor) – 原始的词嵌入
- Returns:
emb_enc – 编码的词嵌入
- Return type:
torch.Tensor
- NORMS = {'batch': <class 'speechbrain.nnet.normalization.BatchNorm1d'>, 'instance': <class 'speechbrain.nnet.normalization.InstanceNorm1d'>, 'layer': <class 'speechbrain.nnet.normalization.LayerNorm'>}
- class speechbrain.lobes.models.g2p.model.TransformerG2P(emb, encoder_emb, char_lin, phn_lin, lin, out, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, custom_src_module=None, custom_tgt_module=None, positional_encoding='fixed_abs_sine', normalize_before=True, kernel_size=15, bias=True, encoder_module='transformer', attention_type='regularMHA', max_length=2500, causal=False, pad_idx=0, encoder_kdim=None, encoder_vdim=None, decoder_kdim=None, decoder_vdim=None, use_word_emb=False, word_emb_enc=None)[source]
基础类:
TransformerInterface基于Transformer的字母到音素模型
- Parameters:
emb (torch.nn.Module) – 嵌入模块
encoder_emb (torch.nn.Module) – 编码器嵌入模块
char_lin (torch.nn.Module) – 一个将输入连接到transformer的线性模块
phn_lin (torch.nn.Module) – 一个线性模块,将输出连接到变压器
out (torch.nn.Module) – 解码器模块(通常是Softmax)
lin (torch.nn.Module) – 用于输出的线性模块
d_model (int) – 编码器/解码器输入中期望的特征数量(默认=512)。
nhead (int) – 多头注意力模型中的头数(默认=8)。
num_encoder_layers (int, optional) – 编码器中的编码层数量。
num_decoder_layers (int, optional) – 解码器中解码层的数量。
dim_ffn (int, optional) – 前馈网络模型隐藏层的维度。
dropout (int, 可选) – 丢弃率的值。
activation (torch.nn.Module, optional) – 前馈网络层的激活函数, 例如,relu 或 gelu 或 swish。
custom_src_module (torch.nn.Module, optional) – 处理源特征以符合预期特征维度的模块。
custom_tgt_module (torch.nn.Module, optional) – 处理源特征到预期特征维度的模块。
positional_encoding (str, 可选) – 使用的位置编码类型。例如,‘fixed_abs_sine’ 表示固定的绝对位置编码。
normalize_before (bool, optional) – 是否在Transformer层中的MHA或FFN之前或之后应用归一化。 默认为True,因为这被证明可以带来更好的性能和训练稳定性。
kernel_size (int, optional) – 当使用Conformer时,卷积层中的核大小。
bias (bool, 可选) – 是否在Conformer卷积层中使用偏置。
encoder_module (str, 可选) – 在编码器中选择Conformer或Transformer。解码器固定为Transformer。
conformer_activation (torch.nn.Module, optional) – 在Conformer卷积层之后使用的激活模块。例如Swish、ReLU等。它必须是一个torch模块。
attention_type (str, 可选) – 在所有Transformer或Conformer层中使用的注意力层类型。 例如 regularMHA 或 RelPosMHA。
max_length (int, 可选) – 输入中目标和源序列的最大长度。 用于位置编码。
causal (bool, 可选) – 编码器是否应该是因果的(解码器总是因果的)。 如果是因果的,Conformer卷积层是因果的。
pad_idx (int) – 填充索引(用于掩码)
encoder_kdim (int, optional) – 编码器键的维度。
encoder_vdim (int, optional) – 编码器值的维度。
decoder_kdim (int, optional) – 解码器键的维度。
decoder_vdim (int, 可选) – 解码器的值维度。
- forward(grapheme_encoded, phn_encoded=None, word_emb=None)[source]
计算前向传播
- Parameters:
grapheme_encoded (torch.Tensor) – 以Torch张量编码的字素
phn_encoded (torch.Tensor) – 编码的音素
word_emb (torch.Tensor) – 词嵌入(如果适用)
- Returns:
p_seq (torch.Tensor) – 序列中各个标记的对数概率
char_lens (torch.Tensor) – 字符长度语法
encoder_out (torch.Tensor) – 编码器状态
attention (torch.Tensor) – 注意力状态
- make_masks(src, tgt, src_len=None, pad_idx=0)[source]
此方法生成用于训练变压器模型的掩码。
- Parameters:
src (torch.Tensor) – 编码器的输入序列(必需)。
tgt (torch.Tensor) – 解码器的输入序列(必需)。
src_len (torch.Tensor) – 与src张量对应的长度。
pad_idx (int) – <pad> 标记的索引(默认=0)。
- Returns:
src_key_padding_mask (torch.Tensor) – 源键填充掩码
tgt_key_padding_mask (torch.Tensor) – 目标键填充掩码
src_mask (torch.Tensor) – 源掩码
tgt_mask (torch.Tensor) – 目标掩码