speechbrain.lobes.models.huggingface_transformers.mbart 模块

该模块支持集成huggingface预训练的mBART模型。 参考: https://arxiv.org/abs/2001.08210

需要安装来自HuggingFace的Transformer: https://huggingface.co/transformers/installation.html

Authors
  • Ha Nguyen 2023

摘要

类:

mBART

该模块支持集成HuggingFace和SpeechBrain预训练的mBART模型。

参考

class speechbrain.lobes.models.huggingface_transformers.mbart.mBART(source, save_path, freeze=True, target_lang='fr_XX', decoder_only=True, share_input_output_embed=True)[source]

基础类: HFTransformersInterface

这个模块支持集成HuggingFace和SpeechBrain预训练的mBART模型。

源论文 mBART: https://arxiv.org/abs/2001.08210 需要安装来自 HuggingFace 的 Transformer: https://huggingface.co/transformers/installation.html

该模型通常用作seq2seq模型的文本解码器。它将自动从HuggingFace下载模型或使用本地路径。

Parameters:
  • source (str) – HuggingFace 中心名称:例如 “facebook/mbart-large-50-many-to-many-mmt”

  • save_path (str) – 下载模型的路径(目录)。

  • freeze (bool (默认值: True)) – 如果为True,模型将被冻结。如果为False,模型将与管道的其余部分一起训练。

  • target_lang (str (默认: fra_Latn (也称为法语)) – 根据NLLB模型的目标语言代码。

  • decoder_only (bool (默认值: True)) – 如果为True,则只取模型的解码器部分(和/或lm_head)。 这在想要将预训练的语音编码器(例如wav2vec)与基于文本的预训练解码器(例如mBART, NLLB)结合时非常有用。

  • share_input_output_embed (bool (默认值: True)) – 如果为True,使用嵌入层作为lm_head。

Example

>>> src = torch.rand([10, 1, 1024])
>>> tgt = torch.LongTensor([[250008,    313,     25,    525,    773,  21525,   4004,      2]])
>>> model_hub = "facebook/mbart-large-50-many-to-many-mmt"
>>> save_path = "savedir"
>>> model = mBART(model_hub, save_path) 
>>> outputs = model(src, tgt) 
forward(src, tgt, pad_idx=0)[source]

此方法使用wav2vec编码器为mt任务实现前向步骤 (与上述相同,但没有编码器堆栈)

Parameters:
  • src (tensor) – 来自w2v2编码器的输出特征(转录)

  • tgt (tensor) – 解码器的目标序列(翻译)(必需)。

  • pad_idx (int) – <pad> 标记的索引(默认=0)。

Returns:

dec_out – 解码器输出。

Return type:

torch.Tensor

decode(tgt, encoder_out, enc_len=None)[source]

该方法实现了变压器模型的解码步骤。

Parameters:
  • tgt (torch.Tensor) – 解码器的输入序列。

  • encoder_out (torch.Tensor) – 编码器的隐藏输出。

  • enc_len (torch.LongTensor) – 编码器状态的实际长度。

Returns:

  • output (torch.Tensor) – 变换器的输出。

  • cross_attention (torch.Tensor) – 注意力值。

custom_padding(x, org_pad, custom_pad)[source]

此方法自定义填充。 SpeechBrain 的默认 pad_idx 为 0。 然而,有些基于文本的模型(如 mBART)将 0 保留用于其他用途, 并使用特定的 pad_idx 进行训练。 此方法将 org_pad 更改为 custom_pad

Parameters:
  • x (torch.Tensor) – 输入张量,带有原始的 pad_idx

  • org_pad (int) – 原始 pad_idx

  • custom_pad (int) – 自定义的pad_idx

Returns:

out – 填充后的输出。

Return type:

torch.Tensor

override_config(config)[source]

如果需要覆盖配置,这里是地方。

Parameters:

config (MBartConfig) – 需要覆盖原始配置。

Return type:

被覆盖的配置