要在GitHub上执行或查看/下载此笔记本
HyperPyYAML 教程
深度学习流程中的一个关键方面是超参数和其他元数据的定义。这些超参数与深度学习算法一起,控制着流程的各个方面,包括模型架构、训练和解码。
在SpeechBrain中,我们在工具包的结构中强调超参数和学习算法之间的明确区分。为了实现这一点,我们将我们的配方分为两个主要文件:train.py 和 train.yaml。
train.yaml 文件遵循由 SpeechBrain 开发的格式,称为“HyperPyYAML”。我们选择扩展 YAML 是因为它在数据序列化方面具有高度可读性。通过在这个已经非常用户友好的格式基础上进行扩展,我们创建了一个超参数的扩展定义,确保我们的实验代码保持简洁且易于阅读。
这里有一个使用 PyTorch 代码的简短示例,以说明 HyperPyYAML 的使用。需要注意的是,使用 HyperPyYAML 并不需要 PyTorch:
%%capture
!pip install torch
!pip install hyperpyyaml
import torch
from hyperpyyaml import load_hyperpyyaml
example_hyperparams = """
base_channels: 32
kernel_size: 11
padding: !ref <kernel_size> // 2
layer1: !new:torch.nn.Conv1d
in_channels: 1
out_channels: !ref <base_channels>
kernel_size: !ref <kernel_size>
padding: !ref <padding>
layer2: !new:torch.nn.Conv1d
in_channels: !ref <base_channels>
out_channels: !ref <base_channels> * 2
kernel_size: !ref <kernel_size>
padding: !ref <padding>
layer3: !new:torch.nn.Conv1d
in_channels: !ref <base_channels> * 2
out_channels: 1
kernel_size: !ref <kernel_size>
padding: !ref <padding>
model: !new:torch.nn.Sequential
- !ref <layer1>
- !new:torch.nn.LeakyReLU
- !ref <layer2>
- !new:torch.nn.LeakyReLU
- !ref <layer3>
"""
# Create model directly by loading the YAML
loaded_hparams = load_hyperpyyaml(example_hyperparams)
model = loaded_hparams["model"]
# Transform a 2-second audio clip
input_audio = torch.rand(1, 1, 32000)
transformed_audio = model(input_audio)
print(transformed_audio.shape)
# Try a different hyperparameter value by overriding the padding value
loaded_hparams = load_hyperpyyaml(example_hyperparams, {"padding": 0})
model = loaded_hparams["model"]
transformed_audio = model(input_audio)
print(transformed_audio.shape)
如本例所示,HyperPyYAML 允许使用组合进行复杂的超参数定义。此外,任何值都可以被覆盖以进行超参数调优。为了理解这一切是如何工作的,让我们首先简要了解一下 YAML 的基础知识。
基本YAML语法
前奏够了:让我们来谈谈YAML!这里有一个简短的YAML片段示例,以及它在加载到Python后的样子:
import yaml
yaml_string = """
foo: 1
bar:
- item1
- item2
baz:
item1: 3.4
item2: True
"""
yaml.safe_load(yaml_string)
如你所见,YAML内置支持多种数据类型,包括字符串、整数、浮点数、布尔值、列表和字典。我们的HyperPyYAML格式保留了所有这些功能。
from hyperpyyaml import load_hyperpyyaml
load_hyperpyyaml(yaml_string)
我们对yaml格式的主要添加是通过YAML标签实现的。标签在项目定义之前添加,并以!为前缀。为了说明标签的使用方式,这里有一个我们做的小添加的示例,即!tuple标签:
yaml_string = """
foo: !tuple (3, 4)
"""
load_hyperpyyaml(yaml_string)
现在你已经了解了YAML的基础知识,是时候继续学习我们的附加内容了!
覆盖
为了使用超参数的各种值运行实验,我们有一个系统可以覆盖yaml文件中列出的值。
overrides = {"foo": 7}
fake_file = """
foo: 2
bar: 5
"""
load_hyperpyyaml(fake_file, overrides)
如本例所示,覆盖可以使用普通的python字典。但是,这种形式不支持python对象。要覆盖python对象,覆盖也可以使用带有HyperPyYAML语法的yaml格式字符串。
load_hyperpyyaml(fake_file, "foo: !new:collections.Counter")
结论
我们自豪地展示我们的HyperPyYAML语法,我们认为它提供了一种可读且简洁的方式来构建超参数定义。此外,它从实验文件中移除了不必要的复杂性,使算法变得清晰。正如第一个示例所示,覆盖非常容易,使得超参数调优变得轻而易举。总的来说,我们发现这个包是深度学习的宝贵工具!
引用SpeechBrain
如果您在研究中或业务中使用SpeechBrain,请使用以下BibTeX条目引用它:
@misc{speechbrainV1,
title={Open-Source Conversational AI with {SpeechBrain} 1.0},
author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
year={2024},
eprint={2407.00463},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
title={{SpeechBrain}: A General-Purpose Speech Toolkit},
author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
year={2021},
eprint={2106.04624},
archivePrefix={arXiv},
primaryClass={eess.AS},
note={arXiv:2106.04624}
}