speechbrain.utils.parameter_transfer 模块
为最简单的参数传递情况提供的便利函数。
使用 speechbrain.utils.checkpoints.Checkpointer 来查找检查点
和参数文件的路径。
- Authors
阿库·柔赫 2020
安德烈亚斯·诺奇 2023
阿德尔·穆门 2023
摘要
类:
协调预训练 |
参考
- class speechbrain.utils.parameter_transfer.Pretrainer(collect_in=None, loadables=None, paths=None, custom_hooks=None, conditions=None)[source]
基础类:
object协调预训练
首先,如果指定了,可以选择性地从某些来源(本地目录、HuggingFace 仓库、基本 URL)收集文件到
collect_in目录中。然后,为每个文件调用加载钩子。
- Parameters:
collect_in (str 或 Path, 可选) – 文件收集的目标目录路径。 如果
None,则文件将从缓存中引用或直接引用(如果可能的话,URL 将失败)。不会有一个包含所有文件的集中目标目录。loadables (mapping) – 从可加载键到对象的映射。这将键连接到实际的对象实例。
paths (mapping) – 从可加载键到文件路径的映射。路径的最后一部分被视为文件名,其余部分被视为“源”,可以是目录路径或像Huggingface hub ID这样的魔法源。 例如:sb/asr-crdnn-libri/lm.ckpt -> 源=sb/asr-crdnn-libri, 文件=lm.ckpt 请注意,在收集时,您可以指定一个默认源,用于所有未指定路径的可加载项。
custom_hooks (mapping) – 从可加载键到参数传输钩子函数的映射。如果你想使用自定义加载函数,请在此处指定。
conditions (mapping) – 一个可选的映射,从可加载的键到条件值,仅在某个标志打开时加载某些元素时有用
- add_paths(paths)[source]
更新不同可加载项的路径。
在收集参数时,这里的路径是首选。请注意,在收集时,您可以指定一个默认源,该源用于所有未指定路径的可加载项。
- Parameters:
paths (mapping) – 从可加载键到文件路径的映射。路径的最后一部分被视为文件名,其余部分被视为“源”,可以是目录路径或像Huggingface hub ID这样的魔法源。 例如:sb/asr-crdnn-libri/lm.ckpt -> 源=sb/asr-crdnn-libri, 文件=lm.ckpt
- add_custom_hooks(custom_hooks)[source]
更新自定义钩子。
在加载参数时,这里的钩子优先于类默认值。
- Parameters:
custom_hooks (mapping) – 从可加载键到参数传输钩子函数的映射。如果你想使用自定义加载函数,请在此处指定。
- add_conditions(conditions)[source]
更新条件。
- Parameters:
conditions (mapping) – 从可加载键到条件值的映射,仅在标志开启时加载某些元素时有用
- static split_path(path)[source]
将路径拆分为源路径和文件名
除了常规路径外,这也处理URL和Huggingface hub路径。
- Parameters:
路径 (str)
- Returns:
str – 源
str – 文件名
- collect_files(default_source=None, use_auth_token=False, local_strategy: LocalStrategy = LocalStrategy.SYMLINK)[source]
从已知路径获取参数,并使用默认源作为回退
实际的参数文件可能存放在其他地方,但这确保了在self.collect_in目录中有一个符号链接。符号链接始终使用文件名中的可加载键。这种标准化使得在例如分布式设置上协调预训练变得更加容易。
如果您将所有内容整齐地组织在一个位置,例如 Huggingface hub 仓库,请使用 default_source。
- Parameters:
default_source (str 或 Path 或 FetchSource) – 这用于每个尚未指定路径的可加载项。 例如,如果可加载项的键为
"asr",则要查找的文件是/asr.ckpt use_auth_token (bool (默认值: False)) – 如果为真,将使用Huggingface的auth_token从HuggingFace Hub加载私有模型, 默认值为False,因为大多数模型是公开的。
local_strategy (speechbrain.utils.fetching.LocalStrategy) – 使用的获取策略,控制远程文件获取的行为,涉及符号链接和复制。 如果未指定
collect_in目录,则忽略此参数。 有关更多详细信息,请参见speechbrain.utils.fetching.fetch()。
- Returns:
从可加载键到本地路径的映射,可以从该路径加载可加载的参数。这个类中没有使用,但可能有用。
- Return type: