speechbrain.dataio.dataloader 模块
PyTorch 兼容的数据加载器
本质上,我们通过增加保存数据加载状态的能力来扩展PyTorch DataLoader,以便可以在一个epoch的中间保存检查点。
Example
>>> import torch
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> # An example "dataset" and its loader
>>> dataset = torch.randn(10, 1)
>>> dataloader = SaveableDataLoader(dataset, num_workers = 3)
>>> # Setup the checkpointer:
>>> tmpdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
>>> # Iterate:
>>> for i, data_point in enumerate(dataloader):
... # Here you would process the data:
... rainfall_amount_prediction = data_point * 4.
... # Now, imagine the experiment gets killed on the fifth batch:
... if i == 4:
... break
... # Luckily, you had just saved a checkpoint:
... if i == 3:
... _ = checkpointer.save_checkpoint(end_of_epoch = False)
>>> # So when you restart the experiment:
>>> new_dataloader = SaveableDataLoader(dataset, num_workers = 3)
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
>>> _ = new_checkpointer.recover_if_possible()
>>> # The dataloader fast-forwards to the position where we left off:
>>> assert next(iter(new_dataloader)) == dataset[4]
- Authors:
阿库·柔赫 2020
摘要
类:
无限循环底层可迭代对象,具有名义上的周期长度 |
|
PyTorch DataLoader 的可保存版本。 |
函数:
在必要时为DDP准备loader_kwargs。 |
|
使用SpeechBrain默认设置创建一个基本的DataLoader。 |
参考
- speechbrain.dataio.dataloader.distributed_loader_specifics(distributed_launch, rank, dataset, loader_kwargs)[source]
在必要时为DDP准备loader_kwargs。
- speechbrain.dataio.dataloader.make_dataloader(dataset, looped_nominal_epoch=None, **loader_kwargs)[source]
使用SpeechBrain默认设置创建一个基本的DataLoader。
对于返回字典的DynamicItemDatasets,使用PaddedBatch作为默认的collate_fn。
洗牌操作由ReproducibleRandomSampler实现。
如果数据集不是可迭代数据集,则数据加载器是SaveableDataLoader。
如果数据集是 webdataset.dataset.Composable,设置默认的 batch_size = None。
也可以持续循环遍历底层的数据加载器, 并在名义上的epoch长度处停止迭代。
- Parameters:
- Returns:
DataLoader – 如果 looped_nominal_epoch 为 None
LoopedLoader – 如果 looped_nominal_epoch 不为 None
- class speechbrain.dataio.dataloader.SaveableDataLoader(*args, **kwargs)[source]
基础:
DataLoaderPyTorch DataLoader 的可保存版本。
请参阅
torch.utils.data.DataLoader以了解使用方法。这个类应该与PyTorch的基本DataLoader完全一样,但它可以使用SpeechBrain的Checkpointer进行检查点保存。注意
1. 可保存性是通过一些不幸的稍微神奇的方式实现的。 2. 数据加载器在进入__iter__后无法恢复。通常这不是问题,因为恢复应该在训练开始之前进行。然而,在评估之前,通常也会恢复性能最好的检查点。因此,如果在进入__iter__后加载检查点,我们只是假设这是出于这个原因。会记录一个警告,但仅此而已。
- class speechbrain.dataio.dataloader.LoopedLoader(loader, epoch_length, batchsize_fn=None)[source]
基础类:
object无限循环一个基础的可迭代对象,具有名义的周期长度
这对于处理IterableDatasets特别有用,尤其是webdataset风格的加载。我们建议在webdataset IterableDataset实例上使用
.repeat(),这样底层的数据加载器自然可以永远继续。- Parameters:
loader (iterable) – 一个DataLoader或其他可迭代对象,可以重复循环。
epoch_length (int) – 名义周期的长度。在此步骤数之后,引发 StopIteration
batchsize_fn (可调用) – 用于确定批量大小的函数,默认为
BatchsizeGuesser