speechbrain.utils.checkpoints 模块
该模块实现了一个检查点保存器和加载器。
实验中的检查点通常需要保存许多不同事物的状态:模型参数、优化器参数、当前是哪个周期等。检查点的保存格式是一个目录,其中每个可保存的事物都有自己的文件。此外,一个特殊的文件保存有关检查点的元信息(默认情况下只是创建时间,但您可以指定您可能希望的任何其他内容,例如验证损失)。
检查点系统的接口要求您指定要保存的内容。这种方法灵活且不依赖于您的实验实际运行的方式。
接口要求您为每个要保存的内容指定名称。这个名称用于在恢复时将正确的参数文件提供给正确的对象。
默认的保存和加载方法仅适用于torch.nn.Modules(及其子类)和torch.optim.Optimizers。如果这些方法不适用于您的对象,您可以为特定实例或类指定自己的保存和/或加载方法。
Example
>>> # Toy example Module:
>>> class Recoverable(torch.nn.Module):
... def __init__(self, param):
... super().__init__()
... self.param = torch.nn.Parameter(torch.tensor([param]))
... def forward(self, x):
... return x * self.param
>>> model = Recoverable(1.)
>>> tempdir = getfixture('tmpdir')
>>> # In simple cases, the module aims to have a terse syntax,
>>> # consisting of three steps.
>>> # 1. Specifying where to save checkpoints and what is included in a
>>> # checkpoint:
>>> checkpointer = Checkpointer(tempdir, {"network": model})
>>> # 2. Recover from the latest checkpoint, if one is found:
>>> checkpointer.recover_if_possible()
>>> # Run your experiment:
>>> data = [(0.1, 0.9), (0.3, 0.8)]
>>> for example, target in data:
... loss = (model(example) - target)**2
... # 3. Save checkpoints, and keep by default just one, the newest:
... ckpt = checkpointer.save_and_keep_only()
- Authors
阿库·柔赫 2020
阿德尔·穆门 2024
摘要
类:
描述一个已保存检查点的NamedTuple |
|
保存检查点并从检查点恢复。 |
函数:
从多个检查点中平均参数。 |
|
从state_dicts的迭代器中生成一个平均的state_dict。 |
|
以最近性作为检查点重要性指标。 |
|
查找与给定对象一起使用的默认保存/加载钩子。 |
|
加载状态字典检查点时调用的钩子。 |
|
根据提供的映射关系映射旧状态字典中的键。 |
|
方法装饰器,将给定方法标记为检查点加载钩子。 |
|
方法装饰器,将给定方法标记为检查点保存钩子。 |
|
方法装饰器,将给定方法标记为参数传递钩子。 |
|
类装饰器,用于注册加载、保存和传输钩子。 |
|
非严格的Torch模块状态字典加载。 |
|
从给定路径加载 |
|
从给定路径立即加载torch.nn.Module的state_dict。 |
|
将obj的参数保存到路径。 |
参考
- speechbrain.utils.checkpoints.map_old_state_dict_weights(state_dict: Dict[str, Tensor], mapping: Dict[str, str]) Dict[str, Tensor][source]
根据提供的映射关系,将旧状态字典中的键进行映射。
注意:此函数将重新映射所有包含旧键的state_dict键。 例如,如果state_dict是{‘model.encoder.layer.0.atn.self.query.weight’: …} 并且映射是{‘.atn’: ‘.attn’},则生成的state_dict将是 {‘model.encoder.layer.0.attn.self.query.weight’: …}。
由于这实际上起到了大规模子字符串替换的作用,部分键匹配(例如在一个图层名称的中间)也会起作用,因此要小心避免误报。
- speechbrain.utils.checkpoints.hook_on_loading_state_dict_checkpoint(state_dict: Dict[str, Tensor]) Dict[str, Tensor][source]
加载状态字典检查点时要调用的钩子。
当加载 state_dict 检查点时调用此钩子。它可用于在将 state_dict 加载到模型之前对其进行修改。
默认情况下,此钩子会将旧的 state_dict 键映射到新的键。
- speechbrain.utils.checkpoints.torch_recovery(obj, path, end_of_epoch)[source]
从给定路径立即加载一个torch.nn.Module的state_dict。
这可以通过以下方式设置为 torch.nn.Modules 的默认值: >>> DEFAULT_LOAD_HOOKS[torch.nn.Module] = torch_recovery
- Parameters:
obj (torch.nn.Module) – 要加载参数的实例。
path (str, pathlib.Path) – 加载路径。
end_of_epoch (bool) – 恢复是否来自一个epoch结束时的检查点。
- speechbrain.utils.checkpoints.torch_patched_state_dict_load(path, device='cpu')[source]
从给定路径使用
torch.load()加载state_dict,并调用SpeechBrain的state_dict加载钩子,例如应用键名修补规则以确保兼容性。state_dict没有进一步的预处理,也不会应用到模型中,请参见torch_recovery()或torch_parameter_transfer()。- Parameters:
path (str, pathlib.Path) – 加载路径。
device (str) – 加载的
state_dict张量应驻留的设备。这是传递给torch.load()的;详情请参阅其文档。
- Return type:
加载的状态字典。
- speechbrain.utils.checkpoints.torch_save(obj, path)[source]
将对象的参数保存到路径。
torch.nn.Modules 的默认保存钩子 用于保存 torch.nn.Module 的 state_dicts。
- Parameters:
obj (torch.nn.Module) – 要保存的实例。
path (str, pathlib.Path) – 保存路径。
- speechbrain.utils.checkpoints.torch_parameter_transfer(obj, path)[source]
非严格的 Torch 模块状态字典加载。
从路径加载一组参数到对象。如果对象中有找不到参数的层,只会记录一个警告。同样,如果路径中有找不到对象中对应层的参数,也只会记录一个警告。
- Parameters:
obj (torch.nn.Module) – 要加载参数的实例。
path (str) – 加载路径。
- speechbrain.utils.checkpoints.mark_as_saver(method)[source]
方法装饰器,将给定方法标记为检查点保存钩子。
参见 register_checkpoint_hooks 的示例。
- Parameters:
method (可调用的) – 要装饰的类的方法。必须能够使用位置参数调用,签名格式为 (instance, path)。例如,以下方法满足此条件:def saver(self, path):
- Return type:
装饰后的方法,标记为检查点保存器。
注意
这不会添加钩子(通过方法装饰器不可能实现), 你还必须用@register_checkpoint_hooks装饰类 只能添加一个方法作为钩子。
- speechbrain.utils.checkpoints.mark_as_loader(method)[source]
方法装饰器,将给定方法标记为检查点加载钩子。
- Parameters:
method (callable) – 要装饰的类的方法。必须可以使用位置参数调用,签名格式为 (instance, path, end_of_epoch)。例如:
def loader(self, path, end_of_epoch):- Return type:
装饰后的方法,注册为检查点加载器。
注意
这不会添加钩子(通过方法装饰器不可能实现), 你还必须用@register_checkpoint_hooks装饰类 只能添加一个方法作为钩子。
- speechbrain.utils.checkpoints.mark_as_transfer(method)[source]
方法装饰器,将给定方法标记为参数传递钩子。
- Parameters:
method (callable) – 要装饰的类的方法。必须可以使用位置参数调用,签名应为 (instance, path)。例如:
def loader(self, path):- Return type:
装饰后的方法,注册为转移方法。
注意
这不会添加钩子(通过方法装饰器不可能实现), 你还必须用@register_checkpoint_hooks装饰类 只能添加一个方法作为钩子。
注意
传输钩子优先于加载器钩子,由
Pretrainer决定。然而,如果没有注册传输钩子,Pretrainer将使用加载器钩子。
- speechbrain.utils.checkpoints.register_checkpoint_hooks(cls, save_on_main_only=True)[source]
类装饰器,用于注册加载、保存和传输钩子。
钩子必须已经用 mark_as_loader 和 mark_as_saver 标记,可能还有 mark_as_transfer。
- Parameters:
cls (class) – 要装饰的类
save_on_main_only (bool) – 默认情况下,保存器仅在单个进程上运行。此参数提供了在所有进程上运行保存器的选项,这对于某些需要先收集数据再保存的保存器是必要的。
- Return type:
带有注册钩子的装饰类
Example
>>> @register_checkpoint_hooks ... class CustomRecoverable: ... def __init__(self, param): ... self.param = int(param) ... ... @mark_as_saver ... def save(self, path): ... with open(path, "w", encoding="utf-8") as fo: ... fo.write(str(self.param)) ... ... @mark_as_loader ... def load(self, path, end_of_epoch): ... del end_of_epoch # Unused here ... with open(path, encoding="utf-8") as fi: ... self.param = int(fi.read())
- speechbrain.utils.checkpoints.get_default_hook(obj, default_hooks)[source]
找到与给定对象一起使用的默认保存/加载钩子。
遵循方法解析顺序,即如果没有为对象本身的类注册钩子,也会搜索对象继承的类。
- Parameters:
obj (实例) – 类的实例。
default_hooks (dict) – 从类到(检查点钩子)函数的映射。
- Return type:
如果没有注册方法,则返回正确的方法或None。
Example
>>> a = torch.nn.Module() >>> get_default_hook(a, DEFAULT_SAVE_HOOKS) == torch_save True
- class speechbrain.utils.checkpoints.Checkpoint(path, meta, paramfiles)
基础:
tuple描述一个已保存检查点的NamedTuple
要从多个检查点中选择一个加载,首先根据这个命名元组对检查点进行过滤和排序。检查点器将路径放在path中,并将字典放在meta中。在保存检查点时,您基本上可以向meta添加任何信息。meta中唯一的默认键是“unixtime”。Checkpoint.paramfiles是一个从可恢复名称到参数文件路径的字典。
- meta
字段编号1的别名
- paramfiles
字段编号2的别名
- path
字段编号 0 的别名
- speechbrain.utils.checkpoints.ckpt_recency(ckpt)[source]
最近性作为检查点重要性指标。
此函数还可以作为如何创建检查点重要性关键函数的示例。这是一个命名函数,但正如你所见,在紧急情况下可以很容易地将其实现为 lambda 函数。
- class speechbrain.utils.checkpoints.Checkpointer(checkpoints_dir, recoverables=None, custom_load_hooks=None, custom_save_hooks=None, allow_partial_load=False)[source]
基础类:
object保存检查点并从它们中恢复。
- Parameters:
checkpoints_dir (str, pathlib.Path) – 保存检查点的目录路径。
recoverables (mapping, optional) – 要恢复的对象。它们需要一个(唯一的)名称:这用于将检查点中的参数连接到正确的可恢复对象。 该名称也用于对象参数的保存文件的文件名中。这些也可以通过 add_recoverable 或 add_recoverables 添加,或者直接修改 checkpointer.recoverables。
custom_load_hooks (mapping, optional) – 从名称[与recoverables中的相同]到函数或方法的映射。 为特定对象设置自定义加载钩子。函数/方法必须可以使用位置参数调用,签名应为(instance, path)。例如:
def loader(self, path)。custom_save_hooks (mapping, optional) – 从名称[与可恢复对象中的名称相同]到函数或方法的映射。 为特定对象设置自定义保存钩子。该函数/方法必须可以使用位置参数调用,签名应为(instance, path)。例如,以下函数满足此要求:def saver(self, path):
allow_partial_load (bool, 可选) – 如果为True,允许加载一个检查点,其中并非每个已注册的可恢复项都找到保存文件。在这种情况下,仅加载找到的保存文件。当为False时,加载此类保存文件将引发RuntimeError。(默认值:False)
Example
>>> import torch >>> #SETUP: >>> tempdir = getfixture('tmpdir') >>> class Recoverable(torch.nn.Module): ... def __init__(self, param): ... super().__init__() ... self.param = torch.nn.Parameter(torch.tensor([param])) ... def forward(self, x): ... return x * self.param >>> recoverable = Recoverable(1.) >>> recoverables = {'recoverable': recoverable} >>> # SETUP DONE. >>> checkpointer = Checkpointer(tempdir, recoverables) >>> first_ckpt = checkpointer.save_checkpoint() >>> recoverable.param.data = torch.tensor([2.]) >>> loaded_ckpt = checkpointer.recover_if_possible() >>> # Parameter has been loaded: >>> assert recoverable.param.data == torch.tensor([1.]) >>> # With this call, by default, oldest checkpoints are deleted: >>> checkpointer.save_and_keep_only() >>> assert first_ckpt not in checkpointer.list_checkpoints()
- add_recoverable(name, obj, custom_load_hook=None, custom_save_hook=None, optional_load=False)[source]
注册一个可恢复的对象,可能带有自定义钩子。
- Parameters:
name (str) – 可恢复对象的唯一名称。用于将保存文件映射到对象。
obj (实例) – 要恢复的对象。
custom_load_hook (callable, optional) – 用于加载对象的保存文件的函数。该函数/方法必须可以使用位置参数签名 (instance, path) 进行调用。例如:def load(self, path):
custom_save_hook (callable, optional) – 用于保存对象参数的回调函数。该函数/方法必须能够使用位置参数以签名 (instance, path) 调用。例如,以下函数满足此条件:def saver(self, path):
optional_load (bool, optional) – 如果为True,允许从检查点中可选地加载一个对象。如果检查点中缺少指定的对象,不会引发错误。这在不同的训练配置之间切换时特别有用,例如将精度从浮点32更改为16。例如,假设你有一个训练检查点,其中不包含
scaler对象。如果你打算在浮点16中继续预训练,而scaler对象是必需的,将其标记为可选可以防止加载错误。如果不将其标记为可选,尝试从浮点32训练的检查点加载scaler对象将会失败,因为该检查点中不存在scaler对象。
- add_recoverables(recoverables)[source]
从给定的映射中更新可恢复字典。
- Parameters:
recoverables (mapping) – 要恢复的对象。 它们需要一个(唯一的)名称:这用于将检查点中的参数连接到正确的可恢复对象。该名称也用于对象参数的保存文件的文件名中。
- save_checkpoint(meta={}, end_of_epoch=True, name=None, verbosity=20)[source]
保存一个检查点。
整个检查点变成一个目录。 将每个注册对象的参数保存在单独的文件中。 还会添加一个元文件。默认情况下,元文件只包含 unixtime(自Unix纪元以来的秒数),但您可以自己添加任何 相关的内容。元信息稍后用于选择要加载的 检查点。
end_of_epoch 的值保存在 meta 中。这可能会影响 epoch 计数器和数据集迭代器加载它们的状态。
对于多进程保存,有些情况下我们可能希望在多个进程上运行保存代码(例如FSDP,我们需要在保存之前收集参数)。这是通过在主进程上创建一个保存文件夹并将其通信给所有进程,然后让每个保存/加载方法控制它应该在一个还是所有进程上保存来实现的。
- Parameters:
- Returns:
namedtuple [见上文],保存的检查点,除非这是在非主进程上运行,在这种情况下返回 None。
- Return type:
- save_and_keep_only(meta={}, end_of_epoch=True, name=None, num_to_keep=1, keep_recent=True, importance_keys=[], max_keys=[], min_keys=[], ckpt_predicate=None, verbosity=20)[source]
保存一个检查点,然后删除最不重要的检查点。
本质上,这在一个调用中结合了
save_checkpoint()和delete_checkpoints(),提供了简短的语法。- Parameters:
meta (mapping, optional) – 一个映射,它被添加到检查点中的元文件中。默认情况下包含键“unixtime”。
end_of_epoch (bool, optional) – 检查点是否在epoch结束时。默认为True。 可能会影响加载。
name (str, optional) – 为您的检查点指定一个自定义名称。 名称仍将添加前缀。如果未提供名称, 将从时间戳和随机唯一ID创建一个名称。
num_to_keep (int, 可选) – 要保留的检查点数量。默认为1。这将删除过滤后剩余的所有检查点。必须 >=0。
keep_recent (bool, 可选) – 是否保留最近的
num_to_keep检查点。importance_keys (list, optional) – 用于排序的关键函数列表(参见内置的sorted函数)。 每个可调用对象定义一个排序顺序,并为每个可调用对象保留num_to_keep个检查点。 保留具有最高键的检查点。 这些函数会传递Checkpoint命名元组(见上文)。
max_keys (list, 可选) – 一个键的列表,将保留这些键的最高值。
min_keys (list, 可选) – 一个键的列表,其中将保留最低的值。
ckpt_predicate (callable, optional) – 使用此选项可以从删除中排除一些检查点。在进行任何排序之前,检查点列表会通过此谓词进行过滤。只有那些
ckpt_predicate为True的检查点才能被删除。该函数使用Checkpoint命名元组(见上文)调用。verbosity (int) – 日志记录级别,默认为 logging.INFO
注意
与 save_checkpoint 不同,这不返回任何内容,因为我们无法保证保存的检查点实际上能够避免被删除。
- find_checkpoint(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None)[source]
从所有可用的检查点中选择一个特定的检查点。
如果未使用
importance_key、max_key和min_key中的任何一个,则将返回最近的检查点。它们中最多只能使用一个。大多数功能实际上是在
find_checkpoints()中实现的,但这里保留为一个有用的接口。- Parameters:
importance_key (callable, optional) – 用于排序的关键函数。 返回最高值的检查点将被选中。 该函数使用Checkpoint命名元组调用。
max_key (str, optional) – 将返回具有此键最高值的检查点。只有具有此键的检查点才会被考虑!
min_key (str, optional) – 将返回具有此键最低值的检查点。只有具有此键的检查点才会被考虑!
ckpt_predicate (callable, optional) – 在排序之前,检查点列表会使用此谓词进行过滤。 请参阅内置的filter函数。 该函数会使用Checkpoint命名元组(见上文)进行调用。 默认情况下,所有检查点都会被考虑。
- Returns:
检查点 – 如果找到。
无 – 如果过滤后没有检查点存在/剩余。
- find_checkpoints(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None, max_num_checkpoints=None)[source]
选择多个检查点。
如果未使用
importance_key、max_key和min_key中的任何一个,则将返回最近的检查点。这些参数中最多只能使用一个。- Parameters:
importance_key (callable, optional) – 用于排序的关键函数。 返回最高值的检查点将被选中。 该函数使用Checkpoint命名元组调用。
max_key (str, optional) – 将返回具有此键最高值的检查点。只有具有此键的检查点才会被考虑!
min_key (str, optional) – 将返回具有此键最低值的检查点。只有具有此键的检查点才会被考虑!
ckpt_predicate (callable, optional) – 在排序之前,检查点列表会使用此谓词进行过滤。 参见内置的filter函数。 该函数使用Checkpoint命名元组(见上文)调用。 默认情况下,所有检查点都会被考虑。
max_num_checkpoints (int, None) – 返回的最大检查点数量,或返回所有找到的检查点。
- Returns:
列表最多包含指定的最大数量的检查点。
- Return type:
- recover_if_possible(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None)[source]
选择一个检查点并从中恢复,如果找到的话。
如果未找到检查点,则不运行恢复。
如果未使用
importance_key、max_key和min_key中的任何一个,则将返回最近的检查点。它们中最多只能使用一个。- Parameters:
importance_key (callable, optional) – 用于排序的关键函数。 返回最高值的检查点将被加载。 该函数使用Checkpoint命名元组调用。
max_key (str, optional) – 将加载具有此键最高值的检查点。 只有具有此键的检查点才会被考虑!
min_key (str, optional) – 将加载具有此键最低值的检查点。 只有具有此键的检查点才会被考虑!
ckpt_predicate (callable, optional) – 在排序之前,检查点列表会使用此谓词进行过滤。 请参阅内置的filter函数。 该函数会使用Checkpoint命名元组(见上文)进行调用。 默认情况下,所有检查点都会被考虑。
- Returns:
检查点 – 如果找到。
无 – 如果过滤后没有检查点存在/剩余。
- load_checkpoint(checkpoint)[source]
加载指定的检查点。
- Parameters:
checkpoint (Checkpoint) – 要加载的检查点。
- delete_checkpoints(*, num_to_keep=1, min_keys=None, max_keys=None, importance_keys=[<function ckpt_recency>], ckpt_predicate=None, verbosity=20)[source]
删除最不重要的检查点。
由于定义重要性的方式可能有很多种(例如最低的WER,最低的损失),用户应提供一系列排序关键函数,每个函数定义一种特定的重要性顺序。本质上,每个重要性关键函数提取一个重要性指标(越高越重要)。对于这些顺序中的每一个,都会保留num_to_keep个检查点。然而,如果每个顺序保留的检查点之间存在重叠,则不会保留额外的检查点,因此保留的检查点总数可能少于:
num_to_keep * len(importance_keys)
- Parameters:
num_to_keep (int, optional) – 要保留的检查点数量。 默认为10。您可以选择保留0。这将删除过滤后剩余的所有检查点。必须 >=0
min_keys (list, optional) – 表示元数据中键的字符串列表。这些值中的最小值将被保留,最多保留num_to_keep个。
max_keys (list, optional) – 表示元数据中键的字符串列表。这些值中的最高值将被保留,最多保留 num_to_keep 个。
importance_keys (list, optional) – 用于排序的关键函数列表(参见内置的sorted函数)。 每个可调用对象定义一个排序顺序,并为每个可调用对象保留num_to_keep个检查点。 明确地说,保留具有最高键值的检查点。 这些函数使用Checkpoint命名元组调用(见上文)。另请参见默认值(ckpt_recency,上文)。默认情况下,删除除最新检查点之外的所有检查点。
ckpt_predicate (callable, optional) – 使用此选项可以从删除中排除一些检查点。在进行任何排序之前,检查点列表会通过此谓词进行过滤。只有那些
ckpt_predicate为True的检查点才能被删除。该函数使用Checkpoint命名元组(见上文)调用。verbosity (日志级别) – 设置此删除操作的日志级别。
注意
必须使用关键字参数调用,以表明你知道自己在做什么。删除是永久性的。
- speechbrain.utils.checkpoints.average_state_dicts(state_dicts)[source]
从state_dicts的迭代器中生成一个平均的state_dict。
请注意,此时会在内存中保留两个state_dicts,这是最低的内存要求。
- Parameters:
state_dicts (iterator, list) – 要平均的state_dicts。
- Returns:
平均的state_dict。
- Return type:
state_dict
- speechbrain.utils.checkpoints.average_checkpoints(checkpoint_list, recoverable_name, parameter_loader=<function load>, averager=<function average_state_dicts>)[source]
从多个检查点计算平均参数。
使用 Checkpointer.find_checkpoints() 获取要平均的检查点列表。 在训练过程中,对最后一些检查点的参数进行平均已被证明有时可以提高性能。
默认的加载器和平均器适用于标准的 PyTorch 模块。
- Parameters:
checkpoint_list (list) – 要平均的检查点列表。
recoverable_name (str) – 可恢复项的名称,其参数被加载并取平均值。
parameter_loader (function) – 一个函数,它接受一个参数,即参数文件的路径,并从该文件加载参数。默认情况下,使用torch.load,它会生成state_dict字典。
averager (function) – 一个函数,它接受一个迭代器,该迭代器遍历由parameter_loader加载的每个检查点的参数,并生成它们的平均值。 请注意,该函数是用迭代器调用的,因此最初长度是未知的;实现应该简单地计算生成的参数集的数量。有关示例,请参见上面的average_state_dicts。它是默认的平均器,并对state_dicts进行平均。
- Returns:
averager函数的输出。
- Return type:
任何
Example
>>> # Consider this toy Module again: >>> class Recoverable(torch.nn.Module): ... def __init__(self, param): ... super().__init__() ... self.param = torch.nn.Parameter(torch.tensor([param])) ... def forward(self, x): ... return x * self.param >>> # Now let's make some checkpoints: >>> model = Recoverable(1.) >>> tempdir = getfixture('tmpdir') >>> checkpointer = Checkpointer(tempdir, {"model": model}) >>> for new_param in range(10): ... model.param.data = torch.tensor([float(new_param)]) ... _ = checkpointer.save_checkpoint() # Suppress output with assignment >>> # Let's average the 3 latest checkpoints >>> # (parameter values 7, 8, 9 -> avg=8) >>> ckpt_list = checkpointer.find_checkpoints(max_num_checkpoints = 3) >>> averaged_state = average_checkpoints(ckpt_list, "model") >>> # Now load that state in the normal way: >>> _ = model.load_state_dict(averaged_state) # Suppress output >>> model.param.data tensor([8.])