speechbrain.dataio.dataloader 模块

PyTorch 兼容的数据加载器

本质上,我们通过增加保存数据加载状态的能力来扩展PyTorch DataLoader,以便可以在一个epoch的中间保存检查点。

Example

>>> import torch
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> # An example "dataset" and its loader
>>> dataset = torch.randn(10, 1)
>>> dataloader = SaveableDataLoader(dataset, num_workers = 3)
>>> # Setup the checkpointer:
>>> tmpdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
>>> # Iterate:
>>> for i, data_point in enumerate(dataloader):
...     # Here you would process the data:
...     rainfall_amount_prediction = data_point * 4.
...     # Now, imagine the experiment gets killed on the fifth batch:
...     if i == 4:
...         break
...     # Luckily, you had just saved a checkpoint:
...     if i == 3:
...         _ = checkpointer.save_checkpoint(end_of_epoch = False)
>>> # So when you restart the experiment:
>>> new_dataloader = SaveableDataLoader(dataset, num_workers = 3)
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
>>> _ = new_checkpointer.recover_if_possible()
>>> # The dataloader fast-forwards to the position where we left off:
>>> assert next(iter(new_dataloader)) == dataset[4]
Authors:
  • 阿库·柔赫 2020

摘要

类:

LoopedLoader

无限循环底层可迭代对象,具有名义上的周期长度

SaveableDataLoader

PyTorch DataLoader 的可保存版本。

函数:

distributed_loader_specifics

在必要时为DDP准备loader_kwargs。

make_dataloader

使用SpeechBrain默认设置创建一个基本的DataLoader。

参考

speechbrain.dataio.dataloader.distributed_loader_specifics(distributed_launch, rank, dataset, loader_kwargs)[source]

在必要时为DDP准备loader_kwargs。

Parameters:
  • distributed_launch (bool) – DDP 标志

  • rank (int) – DDP中的节点排名

  • dataset (Dataset) – 用于创建DataLoader的数据集。

  • loader_kwargs (dict) – 传递给DataLoader的关键字参数,有关选项请参见PyTorch DataLoader。

Returns:

增强的关键字参数传递给DataLoader

Return type:

loader_kwargs

speechbrain.dataio.dataloader.make_dataloader(dataset, looped_nominal_epoch=None, **loader_kwargs)[source]

使用SpeechBrain默认设置创建一个基本的DataLoader。

对于返回字典的DynamicItemDatasets,使用PaddedBatch作为默认的collate_fn。

洗牌操作由ReproducibleRandomSampler实现。

如果数据集不是可迭代数据集,则数据加载器是SaveableDataLoader。

如果数据集是 webdataset.dataset.Composable,设置默认的 batch_size = None。

也可以持续循环遍历底层的数据加载器, 并在名义上的epoch长度处停止迭代。

Parameters:
  • dataset (Dataset) – 用于创建DataLoader的数据集。

  • looped_nominal_epoch (None, int) – 如果给定一个整数,无限循环底层DataLoader并设置一个名义上的epoch长度,以批次(或DataLoader产生的任何内容)为单位。

  • **loader_kwargs (dict) – 传递给DataLoader的关键字参数,有关选项请参见PyTorch DataLoader。

Returns:

  • DataLoader – 如果 looped_nominal_epoch 为 None

  • LoopedLoader – 如果 looped_nominal_epoch 不为 None

class speechbrain.dataio.dataloader.SaveableDataLoader(*args, **kwargs)[source]

基础:DataLoader

PyTorch DataLoader 的可保存版本。

请参阅torch.utils.data.DataLoader以了解使用方法。这个类应该与PyTorch的基本DataLoader完全一样,但它可以使用SpeechBrain的Checkpointer进行检查点保存。

注意

1. 可保存性是通过一些不幸的稍微神奇的方式实现的。 2. 数据加载器在进入__iter__后无法恢复。通常这不是问题,因为恢复应该在训练开始之前进行。然而,在评估之前,通常也会恢复性能最好的检查点。因此,如果在进入__iter__后加载检查点,我们只是假设这是出于这个原因。会记录一个警告,但仅此而已。

class speechbrain.dataio.dataloader.LoopedLoader(loader, epoch_length, batchsize_fn=None)[source]

基础类:object

无限循环一个基础的可迭代对象,具有名义的周期长度

这对于处理IterableDatasets特别有用,尤其是webdataset风格的加载。我们建议在webdataset IterableDataset实例上使用.repeat(),这样底层的数据加载器自然可以永远继续。

Parameters:
  • loader (iterable) – 一个DataLoader或其他可迭代对象,可以重复循环。

  • epoch_length (int) – 名义周期的长度。在此步骤数之后,引发 StopIteration

  • batchsize_fn (可调用) – 用于确定批量大小的函数,默认为 BatchsizeGuesser

save(path)[source]

保存所需的信息。

load(path, end_of_epoch=True)[source]

加载所需的信息。