speechbrain.lobes.models.huggingface_transformers.nllb 模块

该模块支持集成huggingface预训练的NLLB模型。 参考: https://arxiv.org/abs/2207.04672

需要安装来自HuggingFace的Transformer: https://huggingface.co/transformers/installation.html

Authors
  • Ha Nguyen 2023

摘要

类:

NLLB

该模块支持集成HuggingFace和SpeechBrain预训练的NLLB模型。

参考

class speechbrain.lobes.models.huggingface_transformers.nllb.NLLB(source, save_path, freeze=True, target_lang='fra_Latn', decoder_only=True, share_input_output_embed=True)[source]

基础:mBART

该模块支持集成HuggingFace和SpeechBrain预训练的NLLB模型。

源论文 NLLB: https://arxiv.org/abs/2207.04672 需要安装来自 HuggingFace 的 Transformer: https://huggingface.co/transformers/installation.html

该模型通常用作seq2seq模型的文本解码器。它将自动从HuggingFace下载模型或使用本地路径。

目前,HuggingFace的NLLB模型可以使用与mBART模型完全相同的代码加载。 因此,NLLB可以很好地继承mBART类。

Parameters:
  • 来源 (str) – HuggingFace 中心名称:例如 “facebook/nllb-200-1.3B”

  • save_path (str) – 下载模型的路径(目录)。

  • freeze (bool (默认值: True)) – 如果为True,模型将被冻结。如果为False,模型将与管道的其余部分一起训练。

  • target_lang (str (默认: fra_Latn (也称为法语)) – 根据NLLB模型的目标语言代码。

  • decoder_only (bool (默认值: True)) – 如果为True,则只取模型的解码器部分(和/或lm_head)。 这在想要将预训练的语音编码器(例如wav2vec)与基于文本的预训练解码器(例如mBART, NLLB)结合时非常有用。

  • share_input_output_embed (bool (默认值: True)) – 如果为True,使用嵌入层作为lm_head。

Example

>>> import torch
>>> src = torch.rand([10, 1, 1024])
>>> tgt = torch.LongTensor([[256057,    313,     25,    525,    773,  21525,   4004,      2]])
>>> model_hub = "facebook/nllb-200-distilled-600M"
>>> save_path = "savedir"
>>> model = NLLB(model_hub, save_path)
>>> outputs = model(src, tgt)