要在GitHub上执行或查看/下载此笔记本
大脑类
深度学习的一个基本方面涉及多次遍历数据集并更新模型参数,通常称为“训练循环”。为了简化和组织这一过程,SpeechBrain 提供了一个多功能的框架,即“Brain”类,在 speechbrain/core.py 中实现。在每个配方中,这个类被继承,并且其方法被重写,以适应该配方的特定需求。
Brain类的核心方法是fit(),负责遍历数据集,执行模型更新,并管理训练循环。要利用fit(),必须在子类中定义至少两个方法:compute_forward()和compute_objectives()。这些方法处理模型的计算以生成预测,并计算梯度计算所需的损失项。
让我们通过一个简单的例子来说明这一点:
%%capture
# Installing SpeechBrain via pip
BRANCH = 'develop'
!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH
# Clone SpeechBrain repository
!git clone https://github.com/speechbrain/speechbrain/
import torch
import speechbrain as sb
class SimpleBrain(sb.Brain):
def compute_forward(self, batch, stage):
return self.modules.model(batch["input"])
def compute_objectives(self, predictions, batch, stage):
return torch.nn.functional.l1_loss(predictions, batch["target"])
model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain({"model": model}, opt_class=lambda x: torch.optim.SGD(x, 0.1))
data = [{"input": torch.rand(10, 10), "target": torch.rand(10, 10)}]
brain.fit(range(10), data)
只需大约10行代码,我们就可以成功训练一个神经网络模型。这种效率之所以能够实现,是因为Brain类处理了训练中的复杂细节,例如管理train()和eval()状态,或者计算和应用梯度。此外,该类的灵活性允许通过向子类添加方法来覆盖过程中的每一步。这意味着即使是复杂的训练过程,例如生成对抗网络(GAN)中的那些,也可以无缝集成到Brain类中。
在本教程中,我们将首先解释Brain类的参数。随后,我们将深入探讨fit()方法,逐步分解并突出显示在必要时可以覆盖的部分。这些关于类参数和fit()方法的见解构成了理解Brain类功能和多样性的基础。
Brain 类的参数
Brain 类只接受 5 个参数,但每个参数可能有点复杂,因此我们在这里详细解释它们。相关代码只是 __init__ 定义:
def __init__(
self,
modules=None,
opt_class=None,
hparams=None,
run_opts=None,
checkpointer=None,
):
modules 参数
第一个参数接受一个包含torch模块的字典。Brain类接收这个字典并将其转换为Torch ModuleDict。这提供了一种方便的方式将所有参数移动到正确的设备,调用train()和eval(),并在必要时将模块包装在适当的分布式包装器中。
opt_class 参数
Brain 类接受一个用于 PyTorch 优化器的函数定义。选择将其作为输入而不是预先构建的 PyTorch 优化器的原因是,如果需要,Brain 类会自动处理将模块参数包装在分布式包装器中。这需要在参数传递给优化器构造函数之前完成。
要传递一个pytorch优化器构造函数,可以使用lambda,如本教程开头的示例所示。然而,更方便的是SpeechBrain中大多数配方使用的选项:使用HyperPyYAML定义构造函数。!name:标签的作用类似于lambda,创建一个新的构造函数,可用于制作优化器。
optimizer: !name:torch.optim.Adam
lr: 0.1
当然,有时需要零个或多个优化器。在多个优化器的情况下,可以重写init_optimizers方法来分别初始化每个优化器。
hparams 参数
Brain类算法可能依赖于一组应该易于外部控制的超参数,此参数接受一个字典,该字典将使用“点表示法”对所有内部方法可访问。示例如下:
class SimpleBrain(sb.Brain):
def compute_forward(self, batch, stage):
return self.modules.model(batch["input"])
def compute_objectives(self, predictions, batch, stage):
term1 = torch.nn.functional.l1_loss(predictions, batch["target1"])
term2 = torch.nn.functional.mse_loss(predictions, batch["target2"])
return self.hparams.weight1 * term1 + self.hparams.weight2 * term2
hparams = {"weight1": 0.7, "weight2": 0.3}
model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain(
modules={"model": model},
opt_class=lambda x: torch.optim.SGD(x, 0.1),
hparams=hparams,
)
data = [{
"input": torch.rand(10, 10),
"target1": torch.rand(10, 10),
"target2": torch.rand(10, 10),
}]
brain.fit(range(10), data)
run_opts 参数
有许多选项可以控制fit()方法的执行细节,所有这些选项都可以通过此参数传递。一些示例包括启用调试模式、执行设备和分布式执行选项。有关完整列表,请参阅[添加文档链接]。
checkpointer 参数
最后,如果你将一个SpeechBrain检查点传递给Brain类,有几个操作会自动被调用:
优化器参数被添加到检查点器中。
在训练开始时,会加载最近的检查点,并从该点恢复训练。如果训练已完成,则直接结束训练步骤并进入评估阶段。
在训练过程中,默认每15分钟保存一次检查点(可以通过
run_opts中的选项更改或禁用此设置)。在评估开始时,将加载“最佳”检查点,这是根据检查点中记录的指标的最低或最高分数确定的。
fit() 方法
这个方法做了很多,但实际上只用了大约100行代码,所以通过阅读代码本身是可以理解的。我们逐节分解并解释每一部分的作用。首先,让我们简要地看一下参数:
def fit(
self,
epoch_counter,
train_set,
valid_set=None,
progressbar=None,
train_loader_kwargs={},
valid_loader_kwargs={},
):
epoch_counter参数接受一个迭代器,因此当调用fit()时,外部循环会迭代此变量。此参数与EpochCounter类共同设计,使得可以存储 epoch 循环状态。通过此参数,我们可以从上次停止的地方重新开始实验。train_set和valid_set参数接受一个 Torch Dataset 或 DataLoader,它们将加载训练所需的张量。如果没有传递 DataLoader,则会自动构建一个(参见下一节)。progressbar参数控制是否显示一个tqdm进度条,以显示每个 epoch 在数据集中的进度。train_loader_kwargs和valid_loader_kwargs被传递给make_dataloader方法以创建 DataLoader(参见下一节)。
拟合结构
在解决了参数问题之后,我们可以开始查看这个方法的结构。这里有一个简单的图表来展示fit()中所有可重写的调用。我们将在教程的其余部分逐一介绍这些内容。
make_dataloader
fit() 方法的第一步是确保数据格式适合迭代。train_set 和 valid_set 都与它们各自的关键字参数一起传递。以下是实际代码:
if not isinstance(train_set, DataLoader):
train_set = self.make_dataloader(
train_set, stage=sb.Stage.TRAIN, **train_loader_kwargs
)
if valid_set is not None and not isinstance(valid_set, DataLoader):
valid_set = self.make_dataloader(
valid_set,
stage=sb.Stage.VALID,
ckpt_prefix=None,
**valid_loader_kwargs,
)
默认情况下,此方法处理DataLoader创建时可能出现的复杂情况,例如为分布式执行创建DistributedSampler。与fit()调用中的所有其他方法一样,可以通过在Brain的子类定义中创建make_dataloader方法来覆盖此行为。
on_fit_start
除了数据加载器之外,在训练开始之前还需要进行一些设置。以下是相关代码:
self.on_fit_start()
if progressbar is None:
progressbar = self.progressbar
on_fit_start 方法负责处理一些重要的事情,通过分享代码可以最容易地解释这些内容:
def on_fit_start(self):
self._compile_jit()
self._wrap_distributed()
self.init_optimizers()
if self.checkpointer is not None:
self.checkpointer.recover_if_possible(
device=torch.device(self.device)
)
基本上,此方法确保torch模块得到适当的准备,包括jit编译、分布式包装以及使用所有相关参数初始化优化器。如果有检查点,优化器初始化还会将优化器参数添加到检查点中。最后,此方法加载最新的检查点,以便在训练中断时恢复训练。
on_stage_start
下一部分开始迭代周期并准备迭代训练数据。要调整准备工作,可以重写on_stage_start方法,这将允许执行诸如创建容器以存储训练统计信息等操作。
for epoch in epoch_counter:
self.on_stage_start(Stage.TRAIN, epoch)
self.modules.train()
self.nonfinite_count = 0
if self.train_sampler is not None and hasattr(
self.train_sampler, "set_epoch"
):
self.train_sampler.set_epoch(epoch)
last_ckpt_time = time.time()
训练循环
本教程中最长的代码块致力于训练和验证数据循环。然而,它们实际上只做了三件重要的事情:
在DataLoader中的每个批次上调用
fit_batch()。跟踪平均损失并报告。
可选地定期保存检查点,以便可以恢复训练。
这是代码:
enable = progressbar and sb.utils.distributed.if_main_process()
with tqdm(
train_set, initial=self.step, dynamic_ncols=True, disable=not enable,
) as t:
for batch in t:
self.step += 1
loss = self.fit_batch(batch)
self.avg_train_loss = self.update_average(
loss, self.avg_train_loss
)
t.set_postfix(train_loss=self.avg_train_loss)
if self.debug and self.step == self.debug_batches:
break
if (
self.checkpointer is not None
and self.ckpt_interval_minutes > 0
and time.time() - last_ckpt_time
>= self.ckpt_interval_minutes * 60.0
):
run_on_main(self._save_intra_epoch_ckpt)
last_ckpt_time = time.time()
也许最重要的步骤是fit_batch(batch)调用,我们在这里展示了一个简化版本:
def fit_batch(self, batch):
outputs = self.compute_forward(batch, Stage.TRAIN)
loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
loss.backward()
if self.check_gradients(loss):
self.optimizer.step()
self.optimizer.zero_grad()
return loss.detach().cpu()
此方法调用最重要的拟合方法,compute_forward 和 compute_objectives,这两个方法都必须被重写以便使用 Brain 类。然后损失被反向传播,并在应用更新之前检查梯度是否存在非有限值和过大的范数(默认情况下,过大的范数会自动裁剪)。
on_stage_end
在训练循环结束时,调用on_stage_end方法进行潜在的清理操作,例如报告训练统计信息。
self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
self.avg_train_loss = 0.0
self.step = 0
验证循环
与训练循环非常相似,验证循环也会遍历数据加载器并一次处理一批数据。然而,这个循环不会调用fit_batch,而是调用evaluate_batch,它不会反向传播梯度或应用任何更新。
if valid_set is not None:
self.on_stage_start(Stage.VALID, epoch)
self.modules.eval()
avg_valid_loss = 0.0
with torch.no_grad():
for batch in tqdm(
valid_set, dynamic_ncols=True, disable=not enable
):
self.step += 1
loss = self.evaluate_batch(batch, stage=Stage.VALID)
avg_valid_loss = self.update_average(
loss, avg_valid_loss
)
if self.debug and self.step == self.debug_batches:
break
on_stage_end
此方法与训练阶段的方法相同,但这次仅在单个进程上执行,因为通常该进程会涉及写入文件。常见用途包括:更新学习率、保存检查点以及记录一个周期的统计信息。
self.step = 0
run_on_main(
self.on_stage_end,
args=[Stage.VALID, avg_valid_loss, epoch],
)
最后一件非常简单的事情是检查调试模式,以便只运行几个周期。
if self.debug and epoch == self.debug_epochs:
break
恭喜,你现在知道了fit()方法的工作原理,以及为什么它是一个用于运行实验的有用工具。训练模型的所有部分都被分解,烦人的部分得到了处理,同时通过覆盖Brain类的任何部分仍然可以保持完全的灵活性。
evaluate() 方法
此方法以与fit()方法的验证数据大致相同的方式迭代测试数据,包括调用on_stage_start和on_stage_end。另一个被调用的方法是on_evaluate_start()方法,默认情况下会加载最佳检查点进行评估。
结论
Brain 类和特别是 fit() 方法受到了其他流行的 Python 统计和机器学习库的启发,特别是 numpy、scipy、keras 和 PyTorch Lightning。
随着我们添加关于Brain类更高级用法的教程,我们将在此处添加它们的链接。一些计划中的教程示例:
使用Brain类编写GAN
使用Brain类进行分布式训练
基于非梯度的Brain类使用
引用SpeechBrain
如果您在研究中或业务中使用SpeechBrain,请使用以下BibTeX条目引用它:
@misc{speechbrainV1,
title={Open-Source Conversational AI with {SpeechBrain} 1.0},
author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
year={2024},
eprint={2407.00463},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
title={{SpeechBrain}: A General-Purpose Speech Toolkit},
author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
year={2021},
eprint={2106.04624},
archivePrefix={arXiv},
primaryClass={eess.AS},
note={arXiv:2106.04624}
}