speechbrain.utils.parameter_transfer 模块

为最简单的参数传递情况提供的便利函数。

使用 speechbrain.utils.checkpoints.Checkpointer 来查找检查点 和参数文件的路径。

Authors
  • 阿库·柔赫 2020

  • 安德烈亚斯·诺奇 2023

  • 阿德尔·穆门 2023

摘要

类:

Pretrainer

协调预训练

参考

class speechbrain.utils.parameter_transfer.Pretrainer(collect_in=None, loadables=None, paths=None, custom_hooks=None, conditions=None)[source]

基础类:object

协调预训练

首先,如果指定了,可以选择性地从某些来源(本地目录、HuggingFace 仓库、基本 URL)收集文件到 collect_in 目录中。

然后,为每个文件调用加载钩子。

Parameters:
  • collect_in (strPath, 可选) – 文件收集的目标目录路径。 如果 None,则文件将从缓存中引用或直接引用(如果可能的话,URL 将失败)。不会有一个包含所有文件的集中目标目录。

  • loadables (mapping) – 从可加载键到对象的映射。这将键连接到实际的对象实例。

  • paths (mapping) – 从可加载键到文件路径的映射。路径的最后一部分被视为文件名,其余部分被视为“源”,可以是目录路径或像Huggingface hub ID这样的魔法源。 例如:sb/asr-crdnn-libri/lm.ckpt -> 源=sb/asr-crdnn-libri, 文件=lm.ckpt 请注意,在收集时,您可以指定一个默认源,用于所有未指定路径的可加载项。

  • custom_hooks (mapping) – 从可加载键到参数传输钩子函数的映射。如果你想使用自定义加载函数,请在此处指定。

  • conditions (mapping) – 一个可选的映射,从可加载的键到条件值,仅在某个标志打开时加载某些元素时有用

set_collect_in(path)[source]

更改收集路径

add_loadables(loadables)[source]

从给定的映射更新可加载字典。

Parameters:

loadables (mapping) – 从可加载键到对象的映射

add_paths(paths)[source]

更新不同可加载项的路径。

在收集参数时,这里的路径是首选。请注意,在收集时,您可以指定一个默认源,该源用于所有未指定路径的可加载项。

Parameters:

paths (mapping) – 从可加载键到文件路径的映射。路径的最后一部分被视为文件名,其余部分被视为“源”,可以是目录路径或像Huggingface hub ID这样的魔法源。 例如:sb/asr-crdnn-libri/lm.ckpt -> 源=sb/asr-crdnn-libri, 文件=lm.ckpt

add_custom_hooks(custom_hooks)[source]

更新自定义钩子。

在加载参数时,这里的钩子优先于类默认值。

Parameters:

custom_hooks (mapping) – 从可加载键到参数传输钩子函数的映射。如果你想使用自定义加载函数,请在此处指定。

add_conditions(conditions)[source]

更新条件。

Parameters:

conditions (mapping) – 从可加载键到条件值的映射,仅在标志开启时加载某些元素时有用

static split_path(path)[source]

将路径拆分为源路径和文件名

除了常规路径外,这也处理URL和Huggingface hub路径。

Parameters:

路径 (str)

Returns:

  • str – 源

  • str – 文件名

collect_files(default_source=None, use_auth_token=False, local_strategy: LocalStrategy = LocalStrategy.SYMLINK)[source]

从已知路径获取参数,并使用默认源作为回退

实际的参数文件可能存放在其他地方,但这确保了在self.collect_in目录中有一个符号链接。符号链接始终使用文件名中的可加载键。这种标准化使得在例如分布式设置上协调预训练变得更加容易。

如果您将所有内容整齐地组织在一个位置,例如 Huggingface hub 仓库,请使用 default_source。

Parameters:
  • default_source (strPathFetchSource) – 这用于每个尚未指定路径的可加载项。 例如,如果可加载项的键为 "asr",则要查找的文件是 /asr.ckpt

  • use_auth_token (bool (默认值: False)) – 如果为真,将使用Huggingface的auth_token从HuggingFace Hub加载私有模型, 默认值为False,因为大多数模型是公开的。

  • local_strategy (speechbrain.utils.fetching.LocalStrategy) – 使用的获取策略,控制远程文件获取的行为,涉及符号链接和复制。 如果未指定collect_in目录,则忽略此参数。 有关更多详细信息,请参见speechbrain.utils.fetching.fetch()

Returns:

从可加载键到本地路径的映射,可以从该路径加载可加载的参数。这个类中没有使用,但可能有用。

Return type:

dict

is_loadable(name)[source]

如果没有定义条件或对于指定的可加载项,或者条件为真,则返回True

Parameters:

name (str) – 可加载项的名称

Returns:

is_loadable – 是否应加载该项目

Return type:

bool

load_collected()[source]

加载已收集的文件。