speechbrain.lobes.models.huggingface_transformers.mbart 模块
该模块支持集成huggingface预训练的mBART模型。 参考: https://arxiv.org/abs/2001.08210
需要安装来自HuggingFace的Transformer: https://huggingface.co/transformers/installation.html
- Authors
Ha Nguyen 2023
摘要
类:
该模块支持集成HuggingFace和SpeechBrain预训练的mBART模型。 |
参考
- class speechbrain.lobes.models.huggingface_transformers.mbart.mBART(source, save_path, freeze=True, target_lang='fr_XX', decoder_only=True, share_input_output_embed=True)[source]
-
这个模块支持集成HuggingFace和SpeechBrain预训练的mBART模型。
源论文 mBART: https://arxiv.org/abs/2001.08210 需要安装来自 HuggingFace 的 Transformer: https://huggingface.co/transformers/installation.html
该模型通常用作seq2seq模型的文本解码器。它将自动从HuggingFace下载模型或使用本地路径。
- Parameters:
source (str) – HuggingFace 中心名称:例如 “facebook/mbart-large-50-many-to-many-mmt”
save_path (str) – 下载模型的路径(目录)。
freeze (bool (默认值: True)) – 如果为True,模型将被冻结。如果为False,模型将与管道的其余部分一起训练。
target_lang (str (默认: fra_Latn (也称为法语)) – 根据NLLB模型的目标语言代码。
decoder_only (bool (默认值: True)) – 如果为True,则只取模型的解码器部分(和/或lm_head)。 这在想要将预训练的语音编码器(例如wav2vec)与基于文本的预训练解码器(例如mBART, NLLB)结合时非常有用。
share_input_output_embed (bool (默认值: True)) – 如果为True,使用嵌入层作为lm_head。
Example
>>> src = torch.rand([10, 1, 1024]) >>> tgt = torch.LongTensor([[250008, 313, 25, 525, 773, 21525, 4004, 2]]) >>> model_hub = "facebook/mbart-large-50-many-to-many-mmt" >>> save_path = "savedir" >>> model = mBART(model_hub, save_path) >>> outputs = model(src, tgt)
- forward(src, tgt, pad_idx=0)[source]
此方法使用wav2vec编码器为mt任务实现前向步骤 (与上述相同,但没有编码器堆栈)
- Parameters:
src (tensor) – 来自w2v2编码器的输出特征(转录)
tgt (tensor) – 解码器的目标序列(翻译)(必需)。
pad_idx (int) – <pad> 标记的索引(默认=0)。
- Returns:
dec_out – 解码器输出。
- Return type:
torch.Tensor
- decode(tgt, encoder_out, enc_len=None)[source]
该方法实现了变压器模型的解码步骤。
- Parameters:
tgt (torch.Tensor) – 解码器的输入序列。
encoder_out (torch.Tensor) – 编码器的隐藏输出。
enc_len (torch.LongTensor) – 编码器状态的实际长度。
- Returns:
output (torch.Tensor) – 变换器的输出。
cross_attention (torch.Tensor) – 注意力值。