Checkpointing
By checkpointing, we mean saving the model and all other necessary state information (such as optimizer parameters, the current epoch and the current iteration) at a particular point in time. For experiments, this has two main motivations:
Recovery. Continuing an experiment from the middle. A compute-cluster job may run out of time or memory, or some simple error may stop the experiment script before it finishes. In that case, all progress that was not saved to disk is lost.
Early stopping. During training, performance should be monitored on a separate validation set, which gives an estimate of generalization. As training progresses, we expect the validation error to decrease at first. However, if training goes on for too long, the validation error may start to increase again (due to overfitting). After training, we should go back to the model parameters that performed best on the validation set.
Additionally, saving the trained model parameters is important in its own right, so that the model can be used outside the experiment script.
What the SpeechBrain Checkpointer does
The SpeechBrain Checkpointer simply orchestrates checkpointing. It keeps track of everything that should be included in a checkpoint, how each thing is saved and where the checkpoints should go, and it centralizes loading and saving.
The Checkpointer does not actually save anything to disk itself. It either finds a suitable saving function based on type (considering class inheritance), or you can provide a custom hook.
Install dependencies
%%capture
# Installing SpeechBrain via pip
BRANCH = 'develop'
!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH
import speechbrain as sb
import torch
from speechbrain.utils.checkpoints import Checkpointer
The SpeechBrain Checkpointer in a nutshell
Run the following code block many times. Each time the block runs, it trains one epoch and then finishes. Running the block again is similar to restarting an experiment script.
# You have a model, an optimizer and an epoch counter:
model = torch.nn.Linear(1, 1, False)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
epoch_counter = sb.utils.epoch_loop.EpochCounter(10)
# Create a checkpointer:
checkpoint_dir = "./nutshell_checkpoints"
checkpointer = Checkpointer(checkpoint_dir,
                            recoverables={"mdl": model,
                                          "opt": optimizer,
                                          "epochs": epoch_counter})
# Now, before running the training epochs, you want to recover,
# if that is possible (if checkpoints have already been saved.)
# By default, the most recent checkpoint is loaded.
checkpointer.recover_if_possible()
# Then we run an epoch loop:
for epoch in epoch_counter:
    print(f"Starting epoch {epoch}.")
    # Training:
    optimizer.zero_grad()
    prediction = model(torch.tensor([1.]))
    loss = (prediction - torch.tensor([1.]))**2
    loss.backward()
    optimizer.step()
    print(f"Model prediction={prediction.item()}, loss={loss.item()}")
    # And finally at the end, save an end-of-epoch checkpoint:
    checkpointer.save_and_keep_only(meta={"loss": loss.item()})
    # Now, let's "crash" this code block:
    break
else:
    # After training (epoch loop is depleted),
    # we want to recover the best model:
    print("Epoch loop has finished.")
    checkpointer.recover_if_possible(min_key="loss")
    print(f"Best model parameter: {model.weight.data}")
    print(f"Achieved on epoch {epoch_counter.current}.")
# You can use this cell to reset, by deleting all checkpoints:
checkpointer.delete_checkpoints(num_to_keep=0)
What does a checkpoint look like?
The Checkpointer is given a top-level directory, where all the checkpoints go:
checkpoint_dir = "./full_example_checkpoints"
checkpointer = Checkpointer(checkpoint_dir)
Each checkpoint should include many things, such as the model parameters and the training progress.
# You have a model, an optimizer and an epoch counter:
model = torch.nn.Linear(1, 1, True)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
epoch_counter = sb.utils.epoch_loop.EpochCounter(10)
Each entity to save is assigned to the Checkpointer separately, with a unique key, such as a name:
checkpointer.add_recoverable("mdl", model)
checkpointer.add_recoverables({"opt": optimizer, "epoch": epoch_counter})
When a checkpoint is saved, the Checkpointer creates a directory inside the top-level directory. That subdirectory represents this saved checkpoint. Inside the newly created directory, each entity that was given to the Checkpointer gets its own file.
ckpt = checkpointer.save_checkpoint()
print("The checkpoint directory was:", ckpt.path)
for key, filepath in ckpt.paramfiles.items():
    print("The entity with key", key, "was saved to:", filepath)
What is inside each file?
That depends on the entity. The Checkpointer finds a saving "hook" based on type (considering class inheritance) and calls that hook with the object to save and a filepath.
Torch entities (Module, Optimizer) already have default save and load hooks:
torch_hook = sb.utils.checkpoints.get_default_hook(torch.nn.Linear(1,1), sb.utils.checkpoints.DEFAULT_SAVE_HOOKS)
print(torch_hook.__doc__)
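Under the hood, these default hooks essentially save and restore the state_dict. The following is a rough sketch of that behavior, assuming a throwaway file name "linear.ckpt"; it is not the exact library code:
# Approximately what the default torch hooks do (a sketch, not the
# library internals): persist the state_dict, then load it back.
mdl = torch.nn.Linear(1, 1)
torch.save(mdl.state_dict(), "linear.ckpt")     # ~ the default save hook
mdl.load_state_dict(torch.load("linear.ckpt"))  # ~ the default load hook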
Classes can register their own default saving and loading hooks:
@sb.utils.checkpoints.register_checkpoint_hooks
class Duck:
    def __init__(self):
        self.quacks = 0

    def quack(self):
        print("Quack!")
        self.quacks += 1
        print(f"I have already quacked {self.quacks} times.")

    @sb.utils.checkpoints.mark_as_saver
    def save(self, path):
        with open(path, "w") as fo:
            fo.write(str(self.quacks))

    @sb.utils.checkpoints.mark_as_loader
    def load(self, path, end_of_epoch):
        # Irrelevant for ducks:
        del end_of_epoch
        with open(path) as fi:
            self.quacks = int(fi.read())

duck = Duck()
duckpointer = Checkpointer("./duckpoints", {"ducky": duck})
duckpointer.recover_if_possible()
duck.quack()
_ = duckpointer.save_checkpoint()
Meta info
The checkpoint also stores a dictionary of meta info. You can put e.g. the validation loss or some other metric there. By default, only the Unix time is saved.
# Following from the cells of "What does a checkpoint look like?"
checkpointer.save_checkpoint(meta={"loss": 15.5, "validation-type": "fast", "num-examples": 3})
ckpt = checkpointer.save_checkpoint(meta={"loss": 14.4, "validation-type": "full"})
print(ckpt.meta)
This meta info can be used to load the best checkpoint, instead of just the most recent one:
ckpt = checkpointer.recover_if_possible(min_key="loss")
print(ckpt.meta)
More advanced filters are also available:
checkpointer.save_checkpoint(meta={"loss": 12.1, "validation-type": "fast", "num-examples": 2})
ckpt = checkpointer.recover_if_possible(importance_key=lambda ckpt: -ckpt.meta["loss"] / ckpt.meta["num-examples"],
                                        ckpt_predicate=lambda ckpt: ckpt.meta.get("validation-type") == "fast")
print(ckpt.meta)
Keeping a limited number of checkpoints
Neural models these days can be huge, and we don't need to store every checkpoint. Checkpoints can be deleted explicitly, and the same types of filters are available as for recovery:
checkpointer.delete_checkpoints(num_to_keep=1, ckpt_predicate=lambda ckpt: "validation-type" not in ckpt.meta)
But for convenience, there is also a method that saves and deletes at the same time:
checkpointer.save_and_keep_only(meta={"loss": 13.1, "validation-type": "full"},
num_to_keep = 2,
ckpt_predicate=lambda ckpt: ckpt.meta.get("validation-type") == "full")
Pretraining / parameter transfer
Transferring parameters from a pretrained model is different from recovery, although the two share some similarities.
Finding the best checkpoint
The first step in parameter transfer is to find the ideal set of parameters. You can use the Checkpointer for this: point a fresh Checkpointer at the top-level checkpoint directory of an experiment, and find a checkpoint by your own criteria.
ckpt_finder = Checkpointer(checkpoint_dir)
best_ckpt = ckpt_finder.find_checkpoint(min_key="loss",
ckpt_predicate=lambda ckpt: ckpt.meta.get("validation-type") == "full")
best_paramfile = best_ckpt.paramfiles["mdl"]
print("The best parameters were stored in:", best_paramfile)
Transferring the parameters
There is no general recipe for parameter transfer; in many cases you may need to write some custom code to connect the incoming parameters to the new model.
SpeechBrain has an almost trivial implementation for transfer onto another torch Module, which simply loads the layers that match (by name) and ignores saved parameters for which no matching layer is found:
finetune_mdl = torch.nn.Linear(1, 1, False)  # This one doesn't have a bias!
with torch.no_grad():
    print("Before:", finetune_mdl(torch.tensor([1.])))
    sb.utils.checkpoints.torch_parameter_transfer(finetune_mdl, best_paramfile)
    print("And after:", finetune_mdl(torch.tensor([1.])))
Orchestrating the transfer
SpeechBrain has a parameter-transfer orchestrator similar to the Checkpointer: speechbrain.utils.parameter_transfer.Pretrainer. Its main purpose is to implement parameter downloading and loading for speechbrain.pretrained.Pretrained subclasses (such as EncoderDecoderASR), and to help write easily shareable recipes.
Like the Checkpointer, the Pretrainer takes care of mapping parameter files to instances, and it calls the transfer code (implemented as hooks, similar to checkpoint loading).
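As a rough illustration, the sketch below wires up a Pretrainer by hand. The file name "linear.ckpt" is just a stand-in for a real parameter source (a local path, URL or HuggingFace reference), and the exact keyword arguments may vary between SpeechBrain versions; real recipes usually define this mapping in a HyperPyYAML file instead:
# A minimal hand-wired sketch of the Pretrainer workflow (assumed
# setup: "linear.ckpt" holds a state_dict saved earlier).
from speechbrain.utils.parameter_transfer import Pretrainer
model = torch.nn.Linear(1, 1)
pretrainer = Pretrainer(collect_in="./pretrainer_files",
                        loadables={"model": model},
                        paths={"model": "linear.ckpt"})
pretrainer.collect_files()   # gather the parameter files into collect_in
pretrainer.load_collected()  # run the transfer hooks onto `model`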
Citing SpeechBrain
If you use SpeechBrain in your research or business, please cite it with the following BibTeX entries:
@misc{speechbrainV1,
  title={Open-Source Conversational AI with {SpeechBrain} 1.0},
  author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
  year={2024},
  eprint={2407.00463},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.00463},
}

@misc{speechbrain,
  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
  year={2021},
  eprint={2106.04624},
  archivePrefix={arXiv},
  primaryClass={eess.AS},
  note={arXiv:2106.04624}
}