Open In Colab 要在GitHub上执行或查看/下载此笔记本

HyperPyYAML 教程

深度学习流程中的一个关键方面是超参数和其他元数据的定义。这些超参数与深度学习算法一起,控制着流程的各个方面,包括模型架构、训练和解码。

在SpeechBrain中,我们在工具包的结构中强调超参数和学习算法之间的明确区分。为了实现这一点,我们将我们的配方分为两个主要文件:train.pytrain.yaml

train.yaml 文件遵循由 SpeechBrain 开发的格式,称为“HyperPyYAML”。我们选择扩展 YAML 是因为它在数据序列化方面具有高度可读性。通过在这个已经非常用户友好的格式基础上进行扩展,我们创建了一个超参数的扩展定义,确保我们的实验代码保持简洁且易于阅读。

这里有一个使用 PyTorch 代码的简短示例,以说明 HyperPyYAML 的使用。需要注意的是,使用 HyperPyYAML 并不需要 PyTorch:

%%capture
!pip install torch
!pip install hyperpyyaml
import torch
from hyperpyyaml import load_hyperpyyaml

example_hyperparams = """
base_channels: 32
kernel_size: 11
padding: !ref <kernel_size> // 2

layer1: !new:torch.nn.Conv1d
  in_channels: 1
  out_channels: !ref <base_channels>
  kernel_size: !ref <kernel_size>
  padding: !ref <padding>

layer2: !new:torch.nn.Conv1d
  in_channels: !ref <base_channels>
  out_channels: !ref <base_channels> * 2
  kernel_size: !ref <kernel_size>
  padding: !ref <padding>

layer3: !new:torch.nn.Conv1d
  in_channels: !ref <base_channels> * 2
  out_channels: 1
  kernel_size: !ref <kernel_size>
  padding: !ref <padding>

model: !new:torch.nn.Sequential
  - !ref <layer1>
  - !new:torch.nn.LeakyReLU
  - !ref <layer2>
  - !new:torch.nn.LeakyReLU
  - !ref <layer3>
"""

# Create model directly by loading the YAML
loaded_hparams = load_hyperpyyaml(example_hyperparams)
model = loaded_hparams["model"]

# Transform a 2-second audio clip
input_audio = torch.rand(1, 1, 32000)
transformed_audio = model(input_audio)
print(transformed_audio.shape)

# Try a different hyperparameter value by overriding the padding value
loaded_hparams = load_hyperpyyaml(example_hyperparams, {"padding": 0})
model = loaded_hparams["model"]
transformed_audio = model(input_audio)
print(transformed_audio.shape)

如本例所示,HyperPyYAML 允许使用组合进行复杂的超参数定义。此外,任何值都可以被覆盖以进行超参数调优。为了理解这一切是如何工作的,让我们首先简要了解一下 YAML 的基础知识。

基本YAML语法

前奏够了:让我们来谈谈YAML!这里有一个简短的YAML片段示例,以及它在加载到Python后的样子:

import yaml
yaml_string = """
foo: 1
bar:
  - item1
  - item2
baz:
  item1: 3.4
  item2: True
"""
yaml.safe_load(yaml_string)

如你所见,YAML内置支持多种数据类型,包括字符串、整数、浮点数、布尔值、列表和字典。我们的HyperPyYAML格式保留了所有这些功能。

from hyperpyyaml import load_hyperpyyaml
load_hyperpyyaml(yaml_string)

我们对yaml格式的主要添加是通过YAML标签实现的。标签在项目定义之前添加,并以!为前缀。为了说明标签的使用方式,这里有一个我们做的小添加的示例,即!tuple标签:

yaml_string = """
foo: !tuple (3, 4)
"""
load_hyperpyyaml(yaml_string)

现在你已经了解了YAML的基础知识,是时候继续学习我们的附加内容了!

标签 !new:!name:

YAML标签可以包含一个后缀,以更具体地定义它是什么类型的标签。我们使用这个来定义一个能够创建任何Python对象而不仅仅是基本类型的标签。这个标签以!new:开头,并包含对象的类型。例如:

yaml_string = """
foo: !new:collections.Counter
"""
loaded_yaml = load_hyperpyyaml(yaml_string)
loaded_yaml["foo"]
loaded_yaml["foo"].update({"a": 3, "b": 5})
loaded_yaml["foo"]["a"] += 1
loaded_yaml["foo"]

当然,许多 Python 对象在创建时接受参数。这些参数可以通过列表传递位置参数,或通过字典传递关键字参数。

yaml_string = """
foo: !new:collections.Counter
  - [a, b, r, a, c, a, d, a, b, r, a]
bar: !new:collections.Counter
  a: 2
  b: 1
  c: 5
"""
load_hyperpyyaml(yaml_string)

另一个有用的Python对象是函数对象。在HyperPyYAML中,可以使用!name:标签来创建。在幕后,这个标签使用functools.partial来创建一个带有默认参数的新函数定义。例如:

yaml_string = """
foo: !name:collections.Counter
  a: 2
"""
loaded_yaml = load_hyperpyyaml(yaml_string)
loaded_yaml["foo"](b=4)

默认参数可以被覆盖,就像普通的python函数一样

loaded_yaml["foo"](a=3, b=5)

标签 !ref!copy

当然,一些超参数在多个地方使用,所以我们添加了一种引用另一个项目的机制,称为!ref。应用此标签的节点必须是一个包含要复制的节点位置的字符串。子节点可以用方括号访问,与Python中相同。例如:

yaml_string = """
foo:
  a: 3
  b: 4
bar:
  c: !ref <foo>
  d: !ref <foo[b]>
"""
load_hyperpyyaml(yaml_string)

!ref 标签可以支持简单的算术和字符串连接,用于基本的超参数组合。

yaml_string = """
folder1: abc/def
folder2: ghi/jkl
folder3: !ref <folder1>/<folder2>

foo: 1024
bar: 512
baz: !ref <foo> // <bar> + 1
"""
load_hyperpyyaml(yaml_string)

!ref 标签也可以引用对象,在这种情况下,它引用的是同一个对象,而不是副本。如果您希望创建一个副本,请使用 !copy 标签。

yaml_string = """
foo: !new:collections.Counter
  a: 4
bar: !ref <foo>
baz: !copy <foo>
"""
loaded_yaml = load_hyperpyyaml(yaml_string)
loaded_yaml["foo"].update({"b": 10})
print(loaded_yaml["bar"])
print(loaded_yaml["baz"])

其他标签

我们还介绍了各种其他标签:

  • !tuple 用于创建 Python 元组。请注意,这是隐式解析的,因此您不需要显式写出元组标签,只需像在 Python 中一样使用括号即可。

  • !include 直接插入其他 yaml 文件

  • !apply 用于加载并执行一个python函数,存储结果

我们使用!apply在加载yaml文件时设置随机种子,以便每次运行时模型具有相同的参数。结果不会被存储,因为它以__开头。

yaml_string = """
sum: !apply:sum
  - [1, 2]
__set_seed: !apply:torch.manual_seed [1234]
"""
load_hyperpyyaml(yaml_string)

覆盖

为了使用超参数的各种值运行实验,我们有一个系统可以覆盖yaml文件中列出的值。

overrides = {"foo": 7}
fake_file = """
foo: 2
bar: 5
"""
load_hyperpyyaml(fake_file, overrides)

如本例所示,覆盖可以使用普通的python字典。但是,这种形式不支持python对象。要覆盖python对象,覆盖也可以使用带有HyperPyYAML语法的yaml格式字符串。

load_hyperpyyaml(fake_file, "foo: !new:collections.Counter")

结论

我们自豪地展示我们的HyperPyYAML语法,我们认为它提供了一种可读且简洁的方式来构建超参数定义。此外,它从实验文件中移除了不必要的复杂性,使算法变得清晰。正如第一个示例所示,覆盖非常容易,使得超参数调优变得轻而易举。总的来说,我们发现这个包是深度学习的宝贵工具!

引用SpeechBrain

如果您在研究中或业务中使用SpeechBrain,请使用以下BibTeX条目引用它:

@misc{speechbrainV1,
  title={Open-Source Conversational AI with {SpeechBrain} 1.0},
  author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
  year={2024},
  eprint={2407.00463},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
  year={2021},
  eprint={2106.04624},
  archivePrefix={arXiv},
  primaryClass={eess.AS},
  note={arXiv:2106.04624}
}