speechbrain.lobes.models.huggingface_transformers.weighted_ssl 模块
该模块支持集成huggingface预训练的wav2vec2模型。
参考: https://arxiv.org/abs/2006.11477 参考: https://arxiv.org/abs/1904.05862 参考: https://arxiv.org/abs/2110.13900 需要安装来自HuggingFace的Transformer: https://huggingface.co/transformers/installation.html
- Authors
萨拉·扎伊姆 2023
阿德尔·穆门 2023, 2024
摘要
类:
该模块允许在SSL编码器中集成使用来自不同层的加权和表示。 |
参考
- class speechbrain.lobes.models.huggingface_transformers.weighted_ssl.WeightedSSLModel(hub, save_path='', layernorm=False, freeze=False, **kwargs)[source]
-
这个叶节点使得在SSL编码器中能够整合使用来自不同层的加权和表示。
该模型可用作SSL基准测试的固定特征提取器。它将自动从HuggingFace下载模型或使用本地路径。
更多详情请参见 recipes/SSL_benchmark
- Parameters:
Example
>>> inputs = torch.rand([10, 600]) >>> model_hub = "facebook/wav2vec2-base-960h" >>> save_path = "savedir" >>> model = WeightedSSLModel(model_hub, save_path) >>> outputs = model(inputs)