speechbrain.lobes.models.huggingface_transformers.nllb 模块
该模块支持集成huggingface预训练的NLLB模型。 参考: https://arxiv.org/abs/2207.04672
需要安装来自HuggingFace的Transformer: https://huggingface.co/transformers/installation.html
- Authors
Ha Nguyen 2023
摘要
类:
该模块支持集成HuggingFace和SpeechBrain预训练的NLLB模型。 |
参考
- class speechbrain.lobes.models.huggingface_transformers.nllb.NLLB(source, save_path, freeze=True, target_lang='fra_Latn', decoder_only=True, share_input_output_embed=True)[source]
基础:
mBART该模块支持集成HuggingFace和SpeechBrain预训练的NLLB模型。
源论文 NLLB: https://arxiv.org/abs/2207.04672 需要安装来自 HuggingFace 的 Transformer: https://huggingface.co/transformers/installation.html
该模型通常用作seq2seq模型的文本解码器。它将自动从HuggingFace下载模型或使用本地路径。
目前,HuggingFace的NLLB模型可以使用与mBART模型完全相同的代码加载。 因此,NLLB可以很好地继承mBART类。
- Parameters:
来源 (str) – HuggingFace 中心名称:例如 “facebook/nllb-200-1.3B”
save_path (str) – 下载模型的路径(目录)。
freeze (bool (默认值: True)) – 如果为True,模型将被冻结。如果为False,模型将与管道的其余部分一起训练。
target_lang (str (默认: fra_Latn (也称为法语)) – 根据NLLB模型的目标语言代码。
decoder_only (bool (默认值: True)) – 如果为True,则只取模型的解码器部分(和/或lm_head)。 这在想要将预训练的语音编码器(例如wav2vec)与基于文本的预训练解码器(例如mBART, NLLB)结合时非常有用。
share_input_output_embed (bool (默认值: True)) – 如果为True,使用嵌入层作为lm_head。
Example
>>> import torch >>> src = torch.rand([10, 1, 1024]) >>> tgt = torch.LongTensor([[256057, 313, 25, 525, 773, 21525, 4004, 2]]) >>> model_hub = "facebook/nllb-200-distilled-600M" >>> save_path = "savedir" >>> model = NLLB(model_hub, save_path) >>> outputs = model(src, tgt)