speechbrain.core 模块
用于运行实验的核心SpeechBrain代码。
- Authors
彼得·普兰廷加 2020, 2023
阿卜杜勒·赫巴 2020
Mirco Ravanelli 2020
阿库·罗赫 2021
安德烈亚斯·诺奇 2022
Sylvain de Langen 2023
阿德尔·穆门 2023, 2024
摘要
类:
自动混合精度(AMP)的配置。 |
|
Brain 类抽象了数据循环的细节。 |
|
简单的枚举,用于跟踪实验的阶段。 |
函数:
创建输出文件夹和相关实验文件。 |
|
解析实验的命令行参数。 |
参考
- class speechbrain.core.AMPConfig(dtype: dtype)[source]
基础类:
object自动混合精度(AMP)的配置。
- Parameters:
dtype (torch.dtype) – 用于AMP的数据类型。
- dtype: dtype
- speechbrain.core.create_experiment_directory(experiment_directory, hyperparams_to_save=None, overrides={}, log_config='/home/docs/checkouts/readthedocs.org/user_builds/speechbrain/checkouts/latest/speechbrain/log-config.yaml', save_env_desc=True)[source]
创建输出文件夹和相关实验文件。
- speechbrain.core.parse_arguments(arg_list=None)[source]
解析实验的命令行参数。
- Parameters:
arg_list (list, None) – 要解析的参数列表。如果未提供,则从
sys.argv[1:]中读取。- Returns:
param_file (str) – 参数文件的位置。
run_opts (dict) – 运行选项,例如分布式、设备等。
overrides (dict) – 传递给
load_hyperpyyaml的覆盖项。
Example
>>> argv = ['hyperparams.yaml', '--device', 'cuda:1', '--seed', '10'] >>> filename, run_opts, overrides = parse_arguments(argv) >>> filename 'hyperparams.yaml' >>> run_opts["device"] 'cuda:1' >>> overrides 'seed: 10'
- class speechbrain.core.Stage(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]
基础类:
Enum用于跟踪实验阶段的简单枚举。
- TRAIN = 1
- VALID = 2
- TEST = 3
- class speechbrain.core.Brain(modules=None, opt_class=None, hparams=None, run_opts=None, checkpointer=None)[source]
基础类:
objectBrain 类抽象了数据循环的细节。
Brain类的主要目的是实现fit()方法,该方法迭代 epochs 和数据集,目的是将一组模块“拟合”到一组数据上。为了使用
fit()方法,应该继承Brain类,并重写默认行为与用例不匹配的任何方法。对于简单的用例(例如,使用单个数据集训练单个模型),需要重写的方法只有:compute_forward()compute_objectives()
下面的示例说明了如何重写这两个方法。
对于更复杂的用例,例如需要更新的多个模块,可以重写以下方法:
fit_batch()evaluate_batch()
- Parameters:
模块 (字典的str:torch.nn.Module 对) – 如果这些模块具有可训练的参数,默认情况下它们会被传递给优化器,并且会在它们上调用
train()/eval()。opt_class (torch.optim 类) – 一个仅接受参数列表的 torch 优化器构造函数(例如 lambda 或部分函数定义)。默认情况下,在
fit()方法开始时,这将传递modules中的所有模块。可以通过重写configure_optimizers()方法来更改此行为。hparams (dict) – 每个键值对应由一个字符串键和一个在重写方法中使用的超参数组成。这些可以通过
hparams属性访问,使用“点”表示法:例如,self.hparams.model(x)。run_opts –
A set of options to change the runtime environment, including
- debug (bool)
If
True, this will only iterate a few batches for all datasets, to ensure code runs without crashing.- debug_batches (int)
Number of batches to run in debug mode, Default
2.- debug_epochs (int)
Number of epochs to run in debug mode, Default
2. If a non-positive number is passed, all epochs are run.- debug_persistently (bool)
Keep data stored during debug mode (not using /tmp), Default
False.- jit (bool)
Enable to compile all modules using jit, Default
False.- jit_module_keys (list of str)
List of keys in
modulesthat should be jit compiled.- compile (bool)
Enable to compile all modules using torch.compile, Default
False.- compile_module_keys (list of str)
List of keys in
modulesthat should be compiled usingtorch.compile. Iftorch.compileis unavailable, an error is raised.- compile_mode (str)
One of
default,reduce-overhead,max-autotune, Defaultreduce-overhead.- compile_using_fullgraph (bool)
Whether it is ok to break model into several subgraphs, Default
False.- compile_using_dynamic_shape_tracing (bool)
Use dynamic shape tracing for compilation, Default
False.- distributed_backend (str)
One of
nccl,gloo,mpi.- device (str)
The location for performing computations.
- precision (str)
One of
fp32,fp16,bf16.- eval_precision (str)
One of
fp32,fp16,bf16.- auto_mix_prec (bool)
If
True, automatic mixed-precision (fp16) is used. Activate it only with cuda. Note: this is a deprecated feature, and will be removed in the future.- bfloat16_mix_prec (bool)
If
True, automatic mixed-precision (bf16) is used. Activate it only with cuda. Note: this is a deprecated feature, and will be removed in the future.- max_grad_norm (float)
Default implementation of
fit_batch()usesclip_grad_norm_with this value. Default:5.- skip_nonfinite_grads (bool)
If
True, sets gradients to zero if they are non-finite (e.g., NaN, Inf). Default:False.- nonfinite_patience (int)
Number of times to ignore non-finite losses before stopping. Default:
3.- noprogressbar (bool)
Whether to turn off progressbar when training. Default:
False.- ckpt_interval_minutes (float)
Amount of time between saving intra-epoch checkpoints, in minutes, default:
15.0. If non-positive, these are not saved.- ckpt_interval_steps (int)
Number of steps between saving intra-epoch checkpoints. If non-positive, these are not saved. Default:
0.
- checkpointerspeechbrain.Checkpointer
默认情况下,这将用于加载检查点,并且如果训练中断,将添加优化器以继续训练。
Example
>>> from torch.optim import SGD >>> class SimpleBrain(Brain): ... def compute_forward(self, batch, stage): ... return self.modules.model(batch[0]) ... def compute_objectives(self, predictions, batch, stage): ... return torch.nn.functional.l1_loss(predictions, batch[0]) >>> model = torch.nn.Linear(in_features=10, out_features=10) >>> brain = SimpleBrain({"model": model}, opt_class=lambda x: SGD(x, 0.1)) >>> brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],))
- compute_forward(batch, stage)[source]
前向传播,将由子类覆盖。
- Parameters:
batch (torch.Tensor 或 tensors) – 数据加载器中的一个元素,包括用于处理的输入。
stage (Stage) – 实验的阶段:Stage.TRAIN, Stage.VALID, Stage.TEST
- Returns:
所有处理完成后的输出。 直接传递给
compute_objectives()。- Return type:
torch.Tensor 或 torch.Tensors
- compute_objectives(predictions, batch, stage)[source]
计算损失,将由子类覆盖。
- Parameters:
predictions (torch.Tensor 或 torch.Tensors) – 要评估的输出张量或张量。 直接来自
compute_forward()。batch (torch.Tensor 或 tensors) – 数据加载器中的一个元素,包括用于比较的目标。
stage (Stage) – 实验的阶段:Stage.TRAIN, Stage.VALID, Stage.TEST
- Returns:
loss – 一个包含计算损失的张量。
- Return type:
torch.Tensor
- make_dataloader(dataset, stage, ckpt_prefix='dataloader-', **loader_kwargs)[source]
为数据集创建数据加载器。
这被
fit()和evaluate()使用,如果它们只接收数据集。或者,这可以从 Brain 子类外部调用。 在这种情况下,应将 DataLoader 传递给
fit()以代替数据集。Stage.TRAIN DataLoader 会被特殊处理。它有额外的参数用于 shuffle 和 drop_last。在 DDP 中,会创建一个 DistributedSampler(除非数据集是 IterableDataset)。
注意
一些重要的DataLoader参数通过**loader_kwargs传递,例如batch_size、num_workers、pin_memory。
注意
默认情况下,
evaluate()指定 ckpt_prefix=None 以阻止测试 DataLoader 被添加到检查点器中。如果你需要在保存检查点后添加一个 可恢复的(例如,在测试时,在 检查点训练之后),并且仍然能够合理地恢复, 你可能应该指定allow_partial_load=True。- Parameters:
dataset (Dataset) – 用于创建数据加载器的一组数据。如果数据集是DynamicItemDataset,则默认使用PaddedBatch作为collate_fn,除非在loader_kwargs中另有指定。
stage (Stage) – 实验的阶段:Stage.TRAIN, Stage.VALID, Stage.TEST
ckpt_prefix (str, None) – 用于SaveableDataLoader检查点名称的前缀。阶段名称会添加到此前缀以创建完整的键。设置为None则不保存DataLoader。
**loader_kwargs (dict) – 传递给DataLoader的额外关键字参数。 例如,batch_size, num_workers, pin_memory。
- Return type:
输入数据集的DataLoader
- on_fit_start()[source]
在
fit()开始时调用,如果distributed_count > 0并且后端是ddp,则在多个进程上调用。默认实现编译jit模块,初始化优化器,并加载最新的检查点以恢复训练。
- init_optimizers()[source]
在
on_fit_start()期间调用,在参数完全配置后(例如DDP,jit)初始化优化器。此方法的默认实现依赖于在初始化时传递的优化器类,该类仅接受参数列表(例如,lambda 或部分函数定义)。这会创建一个优化所有可训练参数的单一优化器。
如果有多个优化器,请覆盖此类。
- zero_grad(set_to_none=False)[source]
将所有优化的
torch.Tensor``s 的梯度设置为零 如果 ``set_to_none=False(默认),否则设置为None。将梯度设置为None应该可以节省内存,例如在
evaluate()期间,因此可以使用更大的批次。
- on_evaluate_start(max_key=None, min_key=None)[source]
在
evaluate()开始时被调用默认实现加载性能最佳的检查点进行评估,基于存储的指标。
- fit_batch(batch)[source]
拟合一个批次,重写以进行多次更新。
默认实现依赖于一些具有特定行为的方法被定义:
compute_forward()compute_objectives()optimizers_step()
还取决于在初始化时传递的优化器。
- Parameters:
batch (list of torch.Tensors) – 用于训练的数据批次。默认实现假设此批次有两个元素:输入和目标。
- Return type:
分离损失
- check_loss_isfinite(loss)[source]
检查损失是否为有限值。
如果损失不是有限的,记录一条有用的消息并增加
nonfinite_count。 如果nonfinite_count超过了--nonfinite_patience阈值,停止训练 并抛出一个错误。当损失变为NaN或inf时,而参数和梯度保持有限时,此检查特别有用。它有助于防止在训练过程中陷入无限循环。
- Parameters:
loss (tensor) – 在调用
backward()之后但在优化器step()之前的损失张量。
- on_fit_batch_start(batch, should_step)[source]
在
fit_batch()开始时调用。此方法不在AMP上下文管理器下调用。不要假设输入批次会自动转换为较低精度(例如fp16)。
- Parameters:
batch (list of torch.Tensors) – 用于训练的数据批次。默认实现假设 此批次有两个元素:输入和目标。
should_step (boolean) – 是否调用了optimizer.step()。
- evaluate_batch(batch, stage)[source]
评估一个批次,覆盖以进行与训练不同的过程。
默认实现依赖于两个具有特定行为的方法:
compute_forward()compute_objectives()
- fit(epoch_counter, train_set, valid_set=None, progressbar=None, train_loader_kwargs={}, valid_loader_kwargs={})[source]
迭代周期和数据集以改进目标。
依赖于多个可以(或应该)被重写的函数的存在。以下方法被使用并期望具有某种行为:
fit_batch()evaluate_batch()update_average()
如果初始化时使用了 distributed_count > 0 并且 distributed_backend 是 ddp,这通常会处理多进程逻辑,比如将训练数据分割为每个设备的子集,并且只在主进程上保存检查点。
- Parameters:
epoch_counter (可迭代对象) – 每次调用应返回一个表示纪元计数的整数。
train_set (Dataset, DataLoader) – 用于训练的一组数据。如果提供了Dataset,则会自动创建一个DataLoader。如果提供了DataLoader,则直接使用它。
valid_set (Dataset, DataLoader) – 用于验证的一组数据。如果提供了Dataset,则会自动创建一个DataLoader。如果提供了DataLoader,则直接使用它。
progressbar (bool) – 是否在进度条中显示每个epoch的进度。
train_loader_kwargs (dict) – 传递给
make_dataloader()的参数,用于创建 train_loader (如果 train_set 是 Dataset,而不是 DataLoader)。 例如 batch_size, num_workers。 DataLoader 的所有参数都是有效的。valid_loader_kwargs (dict) – 传递给
make_dataloader()的参数,用于创建 valid_loader (如果 valid_set 是 Dataset,而不是 DataLoader)。 例如,batch_size, num_workers。 DataLoader 的所有参数都是有效的。
- Return type:
无
- evaluate(test_set, max_key=None, min_key=None, progressbar=None, test_loader_kwargs={})[source]
迭代测试集并评估大脑性能。默认情况下,加载性能最佳的检查点(使用检查点记录器记录)。
- Parameters:
test_set (Dataset, DataLoader) – 如果提供了DataLoader,则直接迭代。否则传递给
self.make_dataloader()。max_key (str) – 用于找到最佳检查点的键,传递给
on_evaluate_start()。min_key (str) – 用于找到最佳检查点的键,传递给
on_evaluate_start()。progressbar (bool) – 是否在进度条中显示进度。
test_loader_kwargs (dict) – 如果
test_set不是 DataLoader,则传递给make_dataloader()的 Kwargs。注意:loader_kwargs["ckpt_prefix"]会自动被覆盖为None(以便测试 DataLoader 不会添加到检查点器中)。
- Return type:
平均测试损失