要在GitHub上执行或查看/下载此笔记本
超参数优化
作为SpeechBrain项目的一部分实现的许多语音处理任务依赖于超参数的仔细选择,例如:
层数
归一化
隐藏层维度
成本函数中的权重
等等
手动选择这样的超参数可能会很繁琐。本教程将展示如何使用Oríon项目中实现的自动超参数优化技术,以系统的方式自动拟合超参数。
先决条件
导入
import os
安装 SpeechBrain
SpeechBrain 可以从下面列出的 GitHub 仓库下载。
%%capture
# Installing SpeechBrain via pip
BRANCH = 'develop'
!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH
依赖修复
PyYAML 6.0 不向后兼容,需要 5.x 版本以支持 HyperPyYAML
%%capture
!pip install pyyaml==5.4.1
安装Oríon
Oríon 可以使用 pip
或 conda
安装
%%capture
!pip install orion[profet]
from speechbrain.utils import hpopt as hp
更新配方以支持超参数优化
SpeechBrain 提供了一个名为 hpopt
的便捷包装器,它能够将目标值报告给 Orion 或其他工具。
有关如何实现它的完整示例,
将以下导入语句添加到您的配方顶部:
from speechbrain.utils import hpopt as hp
将你的配方的主要代码包装在超参数优化上下文中。将
objective_key
设置为Orion将优化的指标。之前:
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
hparams = load_hyperpyyaml(fin, overrides)
## ...
spk_id_brain = SpkIdBrain(
modules=hparams["modules"],
opt_class=hparams["opt_class"],
hparams=hparams,
run_opts=run_opts,
checkpointer=hparams["checkpointer"],
)
# The `fit()` method iterates the training loop, calling the methods
# necessary to update the parameters of the model. Since all objects
# with changing state are managed by the Checkpointer, training can be
# stopped at any point, and will be resumed on next call.
spk_id_brain.fit(
epoch_counter=spk_id_brain.hparams.epoch_counter,
train_set=datasets["train"],
valid_set=datasets["valid"],
train_loader_kwargs=hparams["dataloader_options"],
valid_loader_kwargs=hparams["dataloader_options"],
)
之后:
```python
with hp.hyperparameter_optimization(objective_key="error") as hp_ctx: # <-- Initialize the context
hparams_file, run_opts, overrides = hp_ctx.parse_arguments(sys.argv[1:]) # <-- Replace sb with hp_ctx
with open(hparams_file) as fin:
hparams = load_hyperpyyaml(fin, overrides)
## ...
spk_id_brain = SpkIdBrain(
modules=hparams["modules"],
opt_class=hparams["opt_class"],
hparams=hparams,
run_opts=run_opts,
checkpointer=hparams["checkpointer"],
)
# The `fit()` method iterates the training loop, calling the methods
# necessary to update the parameters of the model. Since all objects
# with changing state are managed by the Checkpointer, training can be
# stopped at any point, and will be resumed on next call.
spk_id_brain.fit(
epoch_counter=spk_id_brain.hparams.epoch_counter,
train_set=datasets["train"],
valid_set=datasets["valid"],
train_loader_kwargs=hparams["dataloader_options"],
valid_loader_kwargs=hparams["dataloader_options"],
)
```
添加代码以报告统计信息
例如,在on_stage_end
中,当stage == sb.Stage.VALID
时
hp.report_result(stage_stats)
通过此函数报告的最后结果将用于超参数优化。
在objective_key参数中指定的键需要存在于传递给report_result
的字典中。
在你的主要超参数文件
train.yaml
中添加以下行:
hpopt_mode: null
hpopt: null
可选:创建一个单独的YAML文件,覆盖在超参数优化期间使用的任何超参数,这些参数与常规训练期间使用的参数不同,除了正在拟合的参数。典型的方法会减少训练周期和训练样本的数量。
如果被覆盖的参数数量较少,则可以省略此步骤。在这种情况下,它们可以通过命令行传递。
示例:
hpopt.yaml
:
number_of_epochs: 1
ckpt_enable: false
❗ 重要: 大多数配方使用检查点来在每个周期(或自定义时间表)后保存模型的快照,以确保如果训练中断可以恢复。在超参数优化期间,这可能会导致问题,因为如果模型的架构(例如层数、每层神经元数等)从一组超参数值变为下一组,尝试恢复检查点将失败。
一个可能的解决方案是使检查点的运行有条件,并在hpopt.yaml
中禁用它
之前:
self.checkpointer.save_and_keep_only(meta=stats, min_keys=["error"])
之后:
if self.hparams.ckpt_enable:
self.checkpointer.save_and_keep_only(meta=stats, min_keys=["error"])
另一种策略是重新配置检查点,将每次运行保存在单独的目录中。对于这种情况,超参数优化包装器可以提供一个名为trial_id的变量,该变量可以插入到输出路径中。
下面给出了这个策略的一个示例:
hpopt.yaml
:
number_of_epochs: 1
ckpt_enable: False
trial_id: hpopt
output_folder: !ref ./results/speaker_id/<trial_id>
train.yaml
:
# ...
save_folder: !ref <output_folder>/save
# ...
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
checkpoints_dir: !ref <save_folder> #<-- will contain trial_id
recoverables:
embedding_model: !ref <embedding_model>
classifier: !ref <classifier>
normalizer: !ref <mean_var_norm>
counter: !ref <epoch_counter>
执行超参数搜索
选择和准备超参数
从您的超参数文件中选择您希望使用Orion进行优化的超参数。这些超参数需要在顶层可用,以便使用此技术进行拟合。
考虑以下示例文件:
dropout: 0.1
n_mels: 80
encoder: !new:speechbrain.lobes.models.mymodel.MyModel
input_shape: [null, null, !ref <n_mels>]
dropout: !ref <dropout>
cnn_blocks: 3
在上述文件中,n_mels
和 dropout
可用于优化,但 cnn_blocks
不可用。
为了使cnn_blocks
可用于优化,请按以下方式修改:
dropout: 0.1
n_mels: 80
cnn_blocks: 3 # <-- Define at the top level
encoder: !new:speechbrain.lobes.models.mymodel.MyModel
input_shape: [null, null, !ref <n_mels>]
dropout: !ref <dropout>
cnn_blocks: !ref <cnn_blocks> # <-- Introduce a reference
配置Orion
创建一个.yaml
文件,其中包含要使用的Orion算法的配置。
下面给出一个示例:
experiment:
max_trials: 1000
max_broken: 1000
algorithms:
tpe:
seed: 42
n_initial_points: 5
config_file_content = """
experiment:
max_trials: 3
max_broken: 1
algorithms:
tpe:
seed: 42
n_initial_points: 5
"""
config_path = os.path.expanduser("~/config")
if not os.path.exists(config_path):
os.mkdir(config_path)
config_file_path = os.path.join(config_path, "orion-speaker-id.yaml")
with open(config_file_path, "w") as config_file:
print(config_file_content, file=config_file)
有关可用算法的更多信息,请查看Orion Repository。
定义搜索空间
编写一个调用Orion定义搜索空间的shell脚本
示例:
#!/bin/bash
HPOPT_EXPERIMENT_NAME=speaker-id
HPOPT_CONFIG_FILE=$HOME/config/orion-speaker-id.yaml
orion hunt -n $HPOPT_EXPERIMENT_NAME -c $HPOPT_CONFIG_FILE python train.py hparams/$HPARAMS \
--hpopt hpopt.yaml \
--hpopt_mode orion \
--emb_dim~"choices([128,256,512,768,1024])" \
--tdnn_channels~"choices([128,256,512,768,1024])"
如果不使用额外的hpopt.yaml
文件,请将--hpopt hpopt.yaml
替换为--hpopt=True
。
考虑运行下面的独立示例
%env PYTHONPATH=/env/python:/content/speechbrain/
%cd /content/speechbrain/templates/hyperparameter_optimization_speaker_id
!orion hunt -n speaker-id -c $HOME/config/orion-speaker-id.yaml python train.py train.yaml \
--hpopt hpopt.yaml \
--hpopt_mode orion \
--emb_dim~"choices([128,256,512,768,1024])" \
--tdnn_channels~"choices([128,256,512,768,1024])"
检查结果
使用orion info
命令来检查超参数拟合的结果。
该工具将输出关于超参数拟合实验的基本统计信息,包括完成的运行次数、最佳试验的目标值以及与该运行对应的超参数值。
在下面的示例中,最佳目标实现值显示在evaluation下,相应的超参数值显示在params下。
Stats
=====
completed: False
trials completed: 4
best trial:
id: c1a71e0988d70005302ab655d7e391d3
evaluation: 0.2384105920791626
params:
/emb_dim: 128
/tdnn_channels: 128
start time: 2021-11-14 21:01:12.760704
finish time: 2021-11-14 21:13:25.043336
duration: 0:12:12.282632
!orion info --name speaker-id
大规模超参数优化
多GPU
由于Orion只是简单地包装了训练脚本的执行,并使用操作系统外壳为每组超参数启动它,因此支持数据并行(DP)或分布式数据并行(DDP)执行的训练脚本可以在不修改的情况下用于超参数拟合。
有关如何设置DP/DDP实验的信息,请参阅SpeechBrain文档和多GPU注意事项教程。
并行或分布式Oríon
Oríon 本身提供了对并行和分布式超参数拟合的支持。
要在单个节点上使用多个并行工作器,请将--n-workers
参数传递给Oríon CLI。
下面的示例将启动一个有三个工作者的实验:
orion hunt -n $HPOPT_EXPERIMENT_NAME -c $HOPT_CONFIG_FILE --n-workers 3 python train.py hparams/$HPARAMS \
--hpopt hpopt.yaml \
--hpopt_mode orion \
--emb_dim~"choices([128,256,512,768,1024])" \
--tdnn_channels~"choices([128,256,512,768,1024])"
对于更高级的场景,包括在多个节点上进行分布式超参数拟合,请参阅Oríon官方文档中的Parallel Workers页面。
引用SpeechBrain
如果您在研究中或业务中使用SpeechBrain,请使用以下BibTeX条目引用它:
@misc{speechbrainV1,
title={Open-Source Conversational AI with {SpeechBrain} 1.0},
author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
year={2024},
eprint={2407.00463},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
title={{SpeechBrain}: A General-Purpose Speech Toolkit},
author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
year={2021},
eprint={2106.04624},
archivePrefix={arXiv},
primaryClass={eess.AS},
note={arXiv:2106.04624}
}