speechbrain.inference.interfaces 模块
定义了用于使用预训练模型进行简单推理的接口
- Authors:
阿库·罗赫 2021
彼得·普兰廷加 2021
洛伦·卢戈斯奇 2020
Mirco Ravanelli 2020
Titouan Parcollet 2021
阿卜杜勒·赫巴 2021
安德烈亚斯·诺茨 2022, 2023
Pooneh Mousavi 2023
Sylvain de Langen 2023
阿德尔·穆门 2023
普拉迪亚·坎达尔卡 2023
摘要
类:
一个用于预训练模型的混入类,使得可以指定编码管道和解码管道 |
|
使用训练好的模型对新数据进行预测。 |
函数:
从外部源获取并加载一个接口 |
参考
- speechbrain.inference.interfaces.foreign_class(source, hparams_file='hyperparams.yaml', pymodule_file='custom.py', classname='CustomInterface', overrides={}, overrides_must_match=True, savedir=None, use_auth_token=False, download_only=False, huggingface_cache_dir=None, local_strategy: LocalStrategy = LocalStrategy.SYMLINK, **kwargs)[source]
从外部源获取并加载一个接口
源可以是文件系统上的位置或在线/huggingface
pymodule 文件应包含一个具有给定类名的类。返回该类的实例。目的是在文件中有一个自定义的 Pretrained 子类。在加载 Hyperparams YAML 文件之前,pymodule 文件也会被添加到 python 路径中,因此它可以包含所需的任何自定义实现。
超参数文件应包含一个“modules”键,这是一个用于计算的torch模块的字典。
超参数文件应包含一个“pretrainer”键,它是speechbrain.utils.parameter_transfer.Pretrainer
- Parameters:
source (str 或 Path 或 FetchSource) – 用于查找模型的位置。详情请参见
speechbrain.utils.fetching.fetch。hparams_file (str) – 用于构建推理所需模块的超参数文件的名称。必须包含两个键:“modules”和“pretrainer”,如所述。
pymodule_file (str) – 应该获取的Python文件的名称。
classname (str) – 类的名称,创建并返回该类的实例
overrides (dict) – 加载hparams文件时要进行的任何更改。
overrides_must_match (bool) – 当覆盖项与yaml_stream中的相应键不匹配时,是否抛出错误。
savedir (str 或 Path) – 预训练材料的存放位置。如果未指定,则使用缓存。
use_auth_token (bool (默认值: False)) – 如果为真,将使用Huggingface的auth_token从HuggingFace Hub加载私有模型, 默认值为False,因为大多数模型是公开的。
download_only (bool (默认值: False)) – 如果为真,则跳过类和实例的创建。
huggingface_cache_dir (str) – HuggingFace缓存路径;如果为None -> “~/.cache/huggingface” (默认: None)
local_strategy (speechbrain.utils.fetching.LocalStrategy) – 使用的获取策略,控制远程文件获取的行为,包括符号链接和复制。 有关更多详细信息,请参见
speechbrain.utils.fetching.fetch()。**kwargs (dict) – 传递给类构造函数的参数。
- Returns:
从给定的pymodule文件中获取具有给定类名的类的实例。
- Return type:
- class speechbrain.inference.interfaces.Pretrained(modules=None, hparams=None, run_opts=None, freeze_params=True)[source]
基础:
Module使用训练好的模型对新数据进行预测。
这是一个处理一些常见样板代码的基类。 它故意有一个类似于
Brain的接口 - 这些基类处理类似的事情。Pretrained 的子类应实现预训练系统运行的实际逻辑,并添加具有描述性名称的方法(例如,用于 ASR 的 transcribe_file())。
Pretrained 是一个 torch.nn.Module,因此像 .to() 或 .eval() 这样的方法可以工作。子类应提供一个合适的 forward() 实现:按照惯例,它应该是一个接收一批音频信号并运行完整模型(如适用)的方法。
- Parameters:
模块 (字典的字符串:torch.nn.Module 对) – 组成学习系统的 Torch 模块。这些模块可以以特殊方式处理(放置在正确的设备上、冻结等)。这些模块可以作为
self.mods下的属性使用,例如 self.mods.model(x)hparams (dict) – 每个键值对应由一个字符串键和一个在重写方法中使用的超参数组成。这些可以通过
hparams属性访问,使用“点”表示法:例如,self.hparams.model(x)。run_opts (dict) –
从命令行解析的选项。参见
speechbrain.parse_arguments()。 这里支持的列表:设备
数据并行计数
数据并行后端
分布式启动
分布式后端
即时编译
即时编译模块键
编译
编译模块键
编译模式
使用全图编译
使用动态形状跟踪编译
freeze_params (bool) – 是否冻结(requires_grad=False)参数。通常在推理时你会想要冻结参数。同时也会在所有模块上调用 .eval()。
- HPARAMS_NEEDED = []
- MODULES_NEEDED = []
- load_audio(path, savedir=None)[source]
使用此模型的输入规范加载音频文件
使用语音模型时,重要的是使用与训练模型时相同类型的数据。这意味着例如使用相同的采样率和通道数。然而,可以将文件从较高的采样率转换为较低的采样率(下采样)。同样,将立体声文件下混为单声道也很简单。路径可以是本地路径、网页URL或指向huggingface仓库的链接。
- classmethod from_hparams(source, hparams_file='hyperparams.yaml', pymodule_file='custom.py', overrides={}, savedir=None, use_auth_token=False, revision=None, download_only=False, huggingface_cache_dir=None, overrides_must_match=True, local_strategy: LocalStrategy = LocalStrategy.SYMLINK, **kwargs)[source]
根据HyperPyYAML文件从外部源获取并加载
源可以是文件系统上的位置或在线/huggingface
你可以使用 pymodule_file 来包含任何需要的自定义实现:如果该文件存在,那么在加载 Hyperparams YAML 之前,它的位置会被添加到 sys.path 中,因此可以在 YAML 中引用它。
超参数文件应包含一个“modules”键,这是一个用于计算的torch模块的字典。
超参数文件应包含一个“pretrainer”键,它是speechbrain.utils.parameter_transfer.Pretrainer
- Parameters:
source (str) – 用于查找模型的位置。详情请参见
speechbrain.utils.fetching.fetch。hparams_file (str) – 用于构建推理所需模块的超参数文件的名称。必须包含两个键:“modules”和“pretrainer”,如所述。
pymodule_file (str) – 可以获取一个Python文件。这允许包含任何自定义实现。在加载超参数YAML文件之前,文件的位置会被添加到sys.path中,因此可以在YAML中引用它。 这是可选的,但有一个默认值:“custom.py”。如果找不到默认文件,这将被忽略,但如果你提供了一个不同的文件名,那么在找不到文件的情况下会引发错误。
overrides (dict) – 加载hparams文件时要进行的任何更改。
savedir (str 或 Path) – 预训练材料的存放位置。如果未指定,则使用缓存。
use_auth_token (bool (默认值: False)) – 如果为真,将使用Huggingface的auth_token从HuggingFace Hub加载私有模型, 默认值为False,因为大多数模型是公开的。
revision (str) – 与HuggingFace Hub模型版本对应的模型修订版本。 如果您希望将代码固定到HuggingFace上托管的特定模型版本,这将特别有用。
download_only (bool (默认值: False)) – 如果为真,则跳过类和实例的创建。
huggingface_cache_dir (str) – HuggingFace缓存路径;如果为None -> “~/.cache/huggingface” (默认: None)
overrides_must_match (bool) – 覆盖项是否必须与文件中已有的参数匹配。
local_strategy (LocalStrategy, optional) – 使用哪种策略在本地处理文件。(默认:
LocalStrategy.SYMLINK)**kwargs (dict) – 传递给类构造函数的参数。
- Return type:
cls的实例
- class speechbrain.inference.interfaces.EncodeDecodePipelineMixin[source]
基础类:
object一个用于预训练模型的mixin,使得可以指定编码管道和解码管道
- property input_use_padded_data
如果开启,原始的PaddedData实例将被传递给模型。如果关闭,仅使用.data。
- Returns:
result – 是否直接使用填充数据
- Return type:
- property batch_outputs
确定输出管道是在批次上操作还是在单个示例上操作(true 表示批处理)
- Returns:
batch_outputs
- Return type: