speechbrain.lobes.models.huggingface_transformers.weighted_ssl 模块

该模块支持集成huggingface预训练的wav2vec2模型。

参考: https://arxiv.org/abs/2006.11477 参考: https://arxiv.org/abs/1904.05862 参考: https://arxiv.org/abs/2110.13900 需要安装来自HuggingFace的Transformer: https://huggingface.co/transformers/installation.html

Authors
  • 萨拉·扎伊姆 2023

  • 阿德尔·穆门 2023, 2024

摘要

类:

WeightedSSLModel

该模块允许在SSL编码器中集成使用来自不同层的加权和表示。

参考

class speechbrain.lobes.models.huggingface_transformers.weighted_ssl.WeightedSSLModel(hub, save_path='', layernorm=False, freeze=False, **kwargs)[source]

基础类: HFTransformersInterface

这个叶节点使得在SSL编码器中能够整合使用来自不同层的加权和表示。

该模型可用作SSL基准测试的固定特征提取器。它将自动从HuggingFace下载模型或使用本地路径。

更多详情请参见 recipes/SSL_benchmark

Parameters:
  • hub (str) – HuggingFace 中心名称:例如 “facebook/wav2vec2-large-lv60”

  • save_path (str) – 下载模型的路径(目录)。

  • layernorm (bool, (默认值: False)) – 是否在求和之前对层表示进行层归一化

  • freeze (bool (默认值: True)) – 如果为True,模型将被冻结。如果为False,模型将与管道的其余部分一起训练。

  • **kwargs (dict) – 传递给 HFTransformersInterface 的额外参数

Example

>>> inputs = torch.rand([10, 600])
>>> model_hub = "facebook/wav2vec2-base-960h"
>>> save_path = "savedir"
>>> model = WeightedSSLModel(model_hub, save_path)
>>> outputs = model(inputs)
forward(wav, wav_lens=None)[source]

该方法输出SSL编码器层表示的加权和

Parameters:
  • wav (torch.Tensor) – 波形

  • wav_lens (torch.Tensor) – 音频长度

Returns:

weighted_feats – 层表示的加权和。

Return type:

torch.Tensor

override_config(config)[source]

如果需要覆盖配置,这里是地方

Parameters:

config (Wav2Vec2Config) – 需要覆盖原始配置。

Return type:

被覆盖的配置