Shortcuts

torch.utils.data

PyTorch 数据加载工具的核心是 torch.utils.data.DataLoader 类。它表示一个数据集上的 Python 可迭代对象,并支持

这些选项由 DataLoader 的构造函数参数配置,其签名如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

以下部分详细描述了这些选项的效果和用法。

数据集类型

构造函数中最重要的参数是 DataLoaderdataset,它指示从哪个数据集对象加载数据。PyTorch 支持两种不同类型的数据集:

映射样式数据集

映射风格的dataset是一种实现了__getitem__()__len__()协议的数据集,它表示从(可能是非整数的)索引/键到数据样本的映射。

例如,当使用 dataset[idx] 访问这样的数据集时,可以从磁盘上的文件夹中读取第 idx 张图像及其对应的标签。

查看 Dataset 了解更多详情。

可迭代样式数据集

可迭代样式的数据集是IterableDataset子类的实例,它实现了__iter__()协议,并表示数据样本的可迭代对象。这种类型的数据集特别适用于随机读取成本高昂甚至不可能的情况,并且批量大小取决于获取的数据。

例如,这样的数据集,当调用 iter(dataset) 时,可以返回从数据库、远程服务器或甚至实时生成的日志中读取的数据流。

有关更多详细信息,请参阅 IterableDataset

注意

在使用 IterableDataset 进行 多进程数据加载 时。相同的 数据集对象在每个工作进程中被复制,因此必须对副本进行不同的配置以避免数据重复。请参阅 IterableDataset 文档以了解如何实现这一点。

数据加载顺序和Sampler

对于可迭代风格的数据集,数据加载顺序完全由用户定义的可迭代对象控制。这使得实现块读取和动态批量大小(例如,通过每次生成一个批量样本)变得更加容易。

本节的其余部分涉及 映射样式数据集的情况。torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键的序列。 它们表示数据集索引上的可迭代对象。例如,在 随机梯度下降(SGD)的常见情况下,Sampler 可以随机排列索引列表 并逐个生成每个索引,或者为小批量 SGD 生成少量索引。

将根据传递给 DataLoadershuffle 参数自动构建顺序或随机采样器。 或者,用户可以使用 sampler 参数来指定自定义的 Sampler 对象,该对象每次生成下一个要获取的索引/键。

一个自定义的 Sampler,它可以一次生成一批索引列表,可以作为 batch_sampler 参数传递。 自动批处理也可以通过 batch_sizedrop_last 参数启用。有关此内容的更多详细信息,请参阅 下一节

注意

既不兼容 sampler 也不兼容 batch_sampler 与可迭代风格的数据集,因为这类数据集没有键或索引的概念。

加载批处理和非批处理数据

DataLoader 支持通过参数 batch_sizedrop_lastbatch_samplercollate_fn(具有默认函数)自动将单个获取的数据样本整理成批次。

自动批处理(默认)

这是最常见的情况,对应于获取一批数据并将它们整理成批量样本,即包含张量的样本,其中一个维度是批次维度(通常是第一个)。

batch_size(默认值为 1)不为 None 时,数据加载器会生成批量样本而不是单个样本。batch_sizedrop_last 参数用于指定数据加载器如何获取数据集键的批次。对于映射风格的数据集,用户可以另外指定 batch_sampler,它一次生成一个键列表。

注意

batch_sizedrop_last 参数主要用于从 sampler 构建一个 batch_sampler。对于映射风格的数据集,sampler 要么由用户提供,要么根据 shuffle 参数构建。对于可迭代风格的数据集,sampler 是一个虚拟的无限采样器。有关采样器的更多详细信息,请参阅 此部分

注意

当从 可迭代风格的数据集中获取数据时,使用 多进程数据加载drop_last 参数会丢弃每个工作进程数据集副本的最后一个非完整批次。

在使用采样器中的索引获取样本列表后,作为 collate_fn 参数传递的函数用于将样本列表整理成批次。

在这种情况下,从映射式数据集中加载大致相当于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

并且从可迭代风格的dataset加载大致相当于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

可以使用自定义的 collate_fn 来自定义数据整理,例如,将序列数据填充到批次的最大长度。请参阅 本节 以了解更多关于 collate_fn 的信息。

禁用自动批处理

在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者简单地加载单个样本。例如,直接加载批量数据(例如,从数据库中批量读取或读取连续的内存块)可能更便宜,或者批量大小是数据依赖的,或者程序设计为处理单个样本。在这些场景下,最好不要使用自动批处理(其中collate_fn用于整理样本),而是让数据加载器直接返回dataset对象的每个成员。

batch_sizebatch_sampler 都为 Nonebatch_sampler 的默认值已经是 None)时,自动批处理被禁用。从 dataset 获取的每个样本都使用作为 collate_fn 参数传递的函数进行处理。

当自动批处理被禁用时,默认的 collate_fn 只是将 NumPy 数组转换为 PyTorch 张量,并保持其他内容不变。

在这种情况下,从映射式数据集中加载大致相当于:

for index in sampler:
    yield collate_fn(dataset[index])

并且从可迭代风格的dataset加载大致相当于:

for data in iter(dataset):
    yield collate_fn(data)

参见此部分以了解更多关于collate_fn的信息。

使用 collate_fn

当启用或禁用自动批处理时,collate_fn 的使用略有不同。

当自动批处理被禁用时collate_fn 会针对每个单独的数据样本进行调用,并且数据加载器迭代器会生成输出。在这种情况下,默认的 collate_fn 只是将 NumPy 数组转换为 PyTorch 张量。

当启用自动批处理时collate_fn 会在每次调用时接收一个数据样本列表。它需要将输入样本整理成一个批次,以便从数据加载器迭代器中生成。本节的其余部分描述了默认的 collate_fn 的行为(default_collate())。

例如,如果每个数据样本由一个3通道图像和一个整数类标签组成,即数据集的每个元素返回一个元组 (image, class_index),默认的 collate_fn 将这些元组列表整理成一个元组,其中包含批量图像张量和批量类标签张量。特别是,默认的 collate_fn 具有以下属性:

  • 它总是在批次维度前添加一个新的维度。

  • 它自动将NumPy数组和Python数值转换为PyTorch张量。

  • 它保留了数据结构,例如,如果每个样本是一个字典,它将输出一个具有相同键集的字典,但值是批处理的张量(如果值不能转换为张量,则为列表)。对于listtuplenamedtuple等也是如此。

用户可以使用自定义的 collate_fn 来实现自定义批处理,例如,沿第一个维度以外的维度进行批处理,填充不同长度的序列,或添加对自定义数据类型的支持。

如果你遇到 DataLoader 的输出维度或类型与你的预期不符的情况,你可能需要检查你的 collate_fn

单进程和多进程数据加载

一个 DataLoader 默认使用单进程数据加载。

在Python进程中, 全局解释器锁(GIL) 阻止了在多个线程之间真正并行化Python代码。为了避免数据加载阻塞计算代码,PyTorch提供了一个简单的切换方式,通过将参数num_workers 设置为正整数来执行多进程数据加载。

单进程数据加载(默认)

在此模式下,数据获取是在初始化DataLoader的同一进程中完成的。因此,数据加载可能会阻塞计算。然而,当用于在进程之间共享数据的资源(例如,共享内存、文件描述符)有限时,或者当整个数据集很小并且可以完全加载到内存中时,此模式可能更可取。此外,单进程加载通常会显示更易读的错误跟踪,因此对于调试非常有用。

多进程数据加载

将参数 num_workers 设置为正整数将 启用多进程数据加载,并指定加载器工作进程的数量。

警告

经过多次迭代后,加载器工作进程将消耗与父进程相同的CPU内存,用于父进程中从工作进程访问的所有Python对象。如果数据集中包含大量数据(例如,在数据集构建时加载非常大的文件名列)和/或使用大量工作进程(总体内存使用量为工作进程数量 * 父进程大小),这可能会出现问题。最简单的解决方法是使用非引用计数的表示形式(如Pandas、Numpy或PyArrow对象)替换Python对象。有关此问题的原因及如何解决这些问题的示例代码,请查看问题 #13246了解更多详情。

在这种模式下,每次创建一个 DataLoader 的迭代器时(例如,当你调用 enumerate(dataloader)),会创建 num_workers 个工作进程。此时,datasetcollate_fnworker_init_fn 会被传递给每个 工作进程,用于初始化和获取数据。这意味着 数据集访问及其内部IO、转换 (包括 collate_fn)在工作进程中运行。

torch.utils.data.get_worker_info() 返回工作进程中的各种有用信息(包括工作进程ID、数据集副本、初始种子等),并在主进程中返回 None。用户可以在数据集代码和/或 worker_init_fn 中使用此函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。例如,这在分片数据集时特别有用。

对于映射风格的(map-style)数据集,主进程使用 sampler生成索引并将它们发送给工作进程。因此,任何洗牌(shuffle)随机化都是在主进程中完成的,主进程通过分配加载索引来指导加载。

对于可迭代样式的数据集,由于每个工作进程都会获得dataset对象的副本,因此简单的多进程加载通常会导致数据重复。使用torch.utils.data.get_worker_info()和/或worker_init_fn,用户可以独立配置每个副本。(请参阅IterableDataset文档以了解如何实现这一点。)出于类似的原因,在多进程加载中,drop_last参数会丢弃每个工作进程的可迭代样式数据集副本的最后一个非完整批次。

当迭代结束时,或者当迭代器被垃圾回收时,工作线程将被关闭。

警告

通常不建议在多进程加载中返回CUDA张量,因为使用CUDA和在多进程中共享CUDA张量存在许多细微差别(参见多进程中的CUDA)。相反,我们建议使用自动内存固定(即设置pin_memory=True),这样可以实现快速数据传输到支持CUDA的GPU。

特定平台的行为

由于工作线程依赖于Python multiprocessing,因此在Windows上的工作线程启动行为与Unix不同。

  • 在 Unix 上,fork() 是默认的 multiprocessing 启动方法。 使用 fork(),子工作进程通常可以通过克隆的地址空间直接访问 dataset 和 Python 参数函数。

  • 在 Windows 或 MacOS 上,spawn() 是默认的 multiprocessing 启动方法。 使用 spawn(),会启动另一个解释器来运行您的主脚本, 随后运行接收 datasetcollate_fn 和其他参数的内部工作函数,这些参数通过 pickle 序列化传递。

这种单独的序列化意味着在使用多进程数据加载时,您应该采取两个步骤来确保与Windows兼容:

  • 将大部分主脚本的代码包裹在 if __name__ == '__main__': 块中, 以确保在每个工作进程启动时不会再次运行(最有可能导致错误)。你可以将数据集和 DataLoader 实例创建逻辑放在这里,因为它不需要在工作进程中重新执行。

  • 确保任何自定义的 collate_fnworker_init_fndataset 代码被声明为顶层定义,位于 __main__ 检查之外。这确保它们在工作进程中可用。 (这是必需的,因为函数仅作为引用进行序列化,而不是 bytecode。)

多进程数据加载中的随机性

默认情况下,每个工作线程的 PyTorch 种子将被设置为 base_seed + worker_id, 其中 base_seed 是由主进程使用其随机数生成器(RNG)生成的长整数(因此,必须消耗一个 RNG 状态)或指定的 generator。然而,在初始化工作线程时,其他库的种子可能会重复,导致每个工作线程返回相同的随机数。(参见常见问题解答中的这一部分。)

worker_init_fn 中,您可以通过 torch.utils.data.get_worker_info().seedtorch.initial_seed() 访问为每个 worker 设置的 PyTorch 种子,并在数据加载之前使用它来为其他库设置种子。

内存固定

从固定(页锁定)内存中发起的主机到GPU的拷贝速度要快得多。有关何时以及如何一般使用固定内存的更多详细信息,请参阅使用固定内存缓冲区

对于数据加载,将 pin_memory=True 传递给 DataLoader 会自动将获取的数据 张量放入固定内存中,从而加快数据传输到支持 CUDA 的 GPU。

默认的内存固定逻辑仅识别张量、映射和包含张量的可迭代对象。默认情况下,如果固定逻辑遇到一个自定义类型的批次(如果您有一个返回自定义批次类型的 collate_fn,则会发生这种情况),或者如果批次的每个元素都是自定义类型,固定逻辑将无法识别它们,并且将返回该批次(或这些元素)而不固定内存。要为自定义批次或数据类型启用内存固定,请在您的自定义类型上定义一个 pin_memory() 方法。

请参见下面的示例。

示例:

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # 自定义类型上的自定义内存固定方法
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')[源代码]

数据加载器将数据集和采样器结合在一起,并提供对给定数据集的可迭代访问。

The DataLoader 支持映射风格和迭代风格的 数据集,支持单进程或多进程加载,自定义加载顺序以及可选的自动批处理(排序)和内存固定。

请参阅torch.utils.data文档页面以获取更多详细信息。

Parameters
  • 数据集 (数据集) – 从中加载数据的数据集。

  • batch_size (int, 可选) – 每个批次加载多少个样本 (默认值: 1)。

  • shuffle (bool, 可选) – 设置为 True 以在每个 epoch 重新洗牌数据(默认值:False)。

  • 采样器 (采样器可迭代对象可选) – 定义从数据集中抽取样本的策略。可以是任何实现了 Iterable__len__ 的对象。如果指定,shuffle 必须不指定。

  • batch_sampler (采样器可迭代对象可选) – 类似于 sampler,但 一次返回一批索引。与 batch_sizeshufflesamplerdrop_last 互斥。

  • num_workers (int, 可选) – 用于数据加载的子进程数量。0 表示数据将在主进程中加载。 (默认值: 0)

  • collate_fn (Callable, 可选) – 将样本列表合并成一个 小批量张量。 当从 映射式数据集中使用批量加载时使用。

  • pin_memory (bool, 可选) – 如果为True,数据加载器将在返回数据之前将张量复制到设备/CUDA固定内存中。如果你的数据元素是自定义类型,或者你的collate_fn返回的批次是自定义类型,请参见下面的示例。

  • drop_last (bool, 可选) – 设置为 True 以丢弃最后一个不完整的批次, 如果数据集大小不能被批次大小整除。如果 False 并且 数据集大小不能被批次大小整除,则最后一个批次 将会较小。(默认值: False)

  • 超时数值可选)——如果为正数,则表示从工作线程收集一批数据的超时值。应始终为非负数。(默认值:0

  • worker_init_fn (可调用对象, 可选) – 如果不是 None,这将在每个工作子进程上被调用,参数为工作id(一个在 [0, num_workers - 1] 范围内的整数),在种子设定和数据加载之前。(默认值:None

  • multiprocessing_context (strmultiprocessing.context.BaseContext, 可选) – 如果 None,将使用操作系统的默认 multiprocessing 上下文。(默认值: None)

  • 生成器 (torch.Generator, 可选) – 如果不是 None,这个随机数生成器将被 RandomSampler 用于生成随机索引,并被多进程用于生成工作线程的 base_seed。(默认值: None)

  • prefetch_factor (int, 可选, 仅关键字参数) – 每个工作线程预加载的批次数量。2 表示所有工作线程总共会预加载 2 * num_workers 个批次。(默认值取决于为 num_workers 设置的值。如果 num_workers=0 的值为 None。否则,如果 num_workers > 0 的默认值为 2)。

  • persistent_workers (bool, 可选) – 如果True,数据加载器在数据集被消费一次后不会关闭工作进程。这允许保持工作进程的数据集实例存活。(默认值: False)

  • pin_memory_device (str, 可选) – 如果 pin_memoryTrue,则指定要锁定内存的设备。

警告

如果使用 spawn 启动方法,worker_init_fn 不能是一个不可序列化的对象,例如 lambda 函数。请参阅 多进程最佳实践 以了解更多关于 PyTorch 中多进程的详细信息。

警告

len(dataloader) 启发式方法是基于所使用的采样器的长度。 当 dataset 是一个 IterableDataset 时, 它改为基于 len(dataset) / batch_size 返回一个估计值,并根据 drop_last 进行适当的 四舍五入,无论多进程加载配置如何。这代表了 PyTorch 所能做出的最佳猜测,因为 PyTorch 信任用户 dataset 代码在正确处理多进程 加载以避免重复数据。

然而,如果分片导致多个工作线程拥有不完整的最后批次,这个估计仍然可能不准确,因为(1)一个原本完整的批次可能会被分成多个批次,以及(2)当设置drop_last时,可能会丢弃超过一个批次价值的样本。不幸的是,PyTorch通常无法检测到这些情况。

有关这两种数据集类型的更多详细信息以及IterableDataset如何与多进程数据加载交互,请参阅数据集类型

警告

参见可重复性,以及我的数据加载器工作线程返回相同的随机数,以及 多进程数据加载中的随机性笔记,了解与随机种子相关的问题。

class torch.utils.data.Dataset(*args, **kwds)[源代码]

表示一个Dataset的抽象类。

所有表示从键到数据样本映射的数据集都应该继承它。所有子类都应该重写__getitem__(),以支持获取给定键的数据样本。子类也可以选择性地重写__len__(),许多Sampler实现和DataLoader的默认选项都期望返回数据集的大小。子类也可以选择性地实现__getitems__(),以加速批量样本加载。此方法接受批量样本的索引列表并返回样本列表。

注意

DataLoader 默认构建一个索引采样器,该采样器生成整数索引。为了使其与具有非整数索引/键的映射样式数据集一起工作,必须提供一个自定义采样器。

class torch.utils.data.IterableDataset(*args, **kwds)[源代码]

一个可迭代的Dataset。

所有表示数据样本可迭代的数据集都应该继承它。 当数据来自流时,这种形式的数据集特别有用。

所有子类都应该重写 __iter__(),这将返回此数据集中样本的迭代器。

当一个子类与DataLoader一起使用时,数据集中的每个项目将从DataLoader迭代器中生成。当num_workers > 0时,每个工作进程将有一个不同的数据集对象副本,因此通常希望独立配置每个副本,以避免从工作进程返回重复数据。get_worker_info(),在工作进程中调用时,返回有关工作进程的信息。它可以在数据集的__iter__()方法或DataLoaderworker_init_fn选项中使用,以修改每个副本的行为。

示例 1: 在 __iter__() 中将工作负载分配给所有工作线程:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "此示例代码仅适用于 end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # 单进程数据加载,返回完整的迭代器
...             iter_start = self.start
...             iter_end = self.end
...         else:  # 在 worker 进程中
...             # 分割工作负载
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # 应该返回与 range(3, 7) 相同的数据集,即 [3, 4, 5, 6]。
>>> ds = MyIterableDataset(start=3, end=7)

>>> # 单进程加载
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # 多进程加载,使用两个 worker 进程
>>> # Worker 0 获取了 [3, 4]。 Worker 1 获取了 [5, 6]。
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

>>> # 使用更多的 worker
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

示例 2:使用 worker_init_fn 在所有工作线程之间分配工作负载:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "此示例代码仅适用于 end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # 应该给出与 range(3, 7) 相同的数据集,即 [3, 4, 5, 6]。
>>> ds = MyIterableDataset(start=3, end=7)

>>> # 单进程加载
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # 直接进行多进程加载会产生重复数据
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # 定义一个 `worker_init_fn` 函数,用于配置每个数据集副本
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # 此工作进程中的数据集副本
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # 配置数据集仅处理拆分的工作负载
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # 使用自定义的 `worker_init_fn` 进行多进程加载
>>> # 工作进程 0 获取了 [3, 4]。 工作进程 1 获取了 [5, 6]。
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # 使用更多的工作进程
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
class torch.utils.data.TensorDataset(*tensors)[源代码]

包装张量的数据集。

每个样本将通过沿第一个维度索引张量来检索。

Parameters

*张量 (张量) – 具有相同第一维度大小的张量。

class torch.utils.data.StackDataset(*args, **kwargs)[源代码]

数据集作为多个数据集的堆叠。

此类用于将复杂输入数据的不同部分组装在一起,这些数据以数据集的形式提供。

示例

>>> images = ImageDataset()
>>> texts = TextDataset()
>>> tuple_stack = StackDataset(images, texts)
>>> tuple_stack[0] == (images[0], texts[0])
>>> dict_stack = StackDataset(image=images, text=texts)
>>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
Parameters
  • *args (数据集) – 以元组形式返回的堆叠数据集。

  • **kwargs (数据集) – 以字典形式返回的堆叠数据集。

class torch.utils.data.ConcatDataset(datasets)[源代码]

作为多个数据集连接的Dataset。

此类可用于组合不同的现有数据集。

Parameters

数据集 (序列) – 要连接的数据集列表

class torch.utils.data.ChainDataset(datasets)[源代码]

用于链接多个 IterableDataset 的数据集。

此类可用于组合不同的现有数据集流。链式操作是实时完成的,因此使用此类连接大规模数据集将非常高效。

Parameters

数据集 (可迭代IterableDataset) – 要连接在一起的数据集

class torch.utils.data.Subset(dataset, indices)[源代码]

数据集在指定索引处的子集。

Parameters
  • 数据集 (数据集) – 整个数据集

  • indices (序列) – 在整个集合中为子集选择的索引

torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[源代码]

处理每个批次中元素集合类型的通用整理函数。

该函数还打开函数注册表以处理特定元素类型。default_collate_fn_map 为张量、numpy数组、数字和字符串提供默认的整理函数。

Parameters
  • batch – 要合并的单个批次

  • collate_fn_map (可选[字典[联合[类型, 元组[类型, ...]]], 可调用]]) – 从元素类型到相应collate函数的可选字典映射。 如果元素类型不在此字典中, 此函数将按照插入顺序遍历字典的每个键, 如果元素类型是键的子类,则调用相应的collate函数。

示例

>>> def collate_tensor_fn(batch, *, collate_fn_map):
>>> # 扩展此函数以处理张量批次
...     return torch.stack(batch, 0)
>>> def custom_collate(batch):
...     collate_map = {torch.Tensor: collate_tensor_fn}
...     return collate(batch, collate_fn_map=collate_map)
>>> # 通过就地修改 `default_collate_fn_map` 扩展 `default_collate`
>>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

注意

每个排序函数都需要一个位置参数用于批次,以及一个关键字参数用于排序函数的字典,如collate_fn_map

torch.utils.data.default_collate(batch)[源代码]

接收一批数据并将批次中的元素放入具有额外外维度的张量中 - 批次大小。

确切的输出类型可以是 torch.TensorSequencetorch.Tensortorch.Tensor 的集合,或者保持不变,具体取决于输入类型。 当在 DataLoader 中定义了 batch_sizebatch_sampler 时,此函数用作默认的整理函数。

这里是基于批次内元素类型的通用输入类型到输出类型的映射:

  • torch.Tensor -> torch.Tensor(添加了外部维度批次大小)

  • NumPy 数组 -> torch.Tensor

  • 浮点数 -> torch.Tensor

  • 整数 -> torch.Tensor

  • 字符串 -> 字符串 (不变)

  • 字节 -> 字节 (未改变)

  • Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

  • NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

  • 序列[V1_i, V2_i, …] -> 序列[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

Parameters

batch – 要合并的单个批次

示例

>>> # 使用一批 `int` 的示例:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # 使用一批 `str` 的示例:
>>> default_collate(['a', 'b', 'c'])
['a', 'b', 'c']
>>> # 使用 `Map` 在批次中的示例:
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
>>> # 使用 `NamedTuple` 在批次中的示例:
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # 使用 `Tuple` 在批次中的示例:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # 使用 `List` 在批次中的示例:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
>>> # 扩展 `default_collate` 以处理特定类型的两种选项
>>> # 选项 1:编写自定义的 collate 函数并调用 `default_collate`
>>> def custom_collate(batch):
...     elem = batch[0]
...     if isinstance(elem, CustomType):  # 一些自定义条件
...         return ...
...     else:  # 回退到 `default_collate`
...         return default_collate(batch)
>>> # 选项 2:就地修改 `default_collate_fn_map`
>>> def collate_customtype_fn(batch, *, collate_fn_map=None):
...     return ...
>>> default_collate_fn_map.update(CustoType, collate_customtype_fn)
>>> default_collate(batch)  # 自动处理 `CustomType`
torch.utils.data.default_convert(data)[源代码]

将每个 NumPy 数组元素转换为 torch.Tensor

如果输入是序列集合映射,它会尝试将每个元素转换为torch.Tensor。 如果输入不是NumPy数组,则保持不变。 当batch_samplerbatch_sizeDataLoader中未定义时,此功能用作默认的整理函数。

通用输入类型到输出类型的映射与default_collate()类似。更多详细信息请参见该部分的描述。

Parameters

数据 – 要转换的单个数据点

示例

>>> # 使用 `int` 的示例
>>> default_convert(0)
0
>>> # 使用 NumPy 数组的示例
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # 使用 NamedTuple 的示例
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
Point(x=tensor(0), y=tensor(0))
>>> # 使用 List 的示例
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
[tensor([0, 1]), tensor([2, 3])]
torch.utils.data.get_worker_info()[源代码]

返回有关当前 DataLoader 迭代器工作进程的信息。

当在工作线程中调用时,这将返回一个保证具有以下属性的对象:

  • id: 当前工作者的ID。

  • num_workers: 工作者的总数。

  • seed: 为当前工作进程设置的随机种子。此值由主进程的随机数生成器和工作进程的ID决定。有关更多详细信息,请参阅DataLoader的文档。

  • dataset: 在此进程中的数据集对象副本。请注意,这将与主进程中的对象不同,并且在不同进程中也是不同的对象。

在主进程中调用时,这将返回 None

注意

当在传递给DataLoaderworker_init_fn中使用时,此方法可用于为每个工作进程设置不同的配置,例如,使用worker_id来配置dataset对象以仅读取分片数据集的特定部分,或使用seed为数据集代码中使用的其他库设置种子。

Return type

可选[WorkerInfo]

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[源代码]

随机将数据集拆分为给定长度的非重叠新数据集。

如果给定一个总和为1的分数列表,长度将自动计算为每个提供的分数的floor(frac * len(dataset))。

计算长度后,如果有任何余数,将以轮询方式将1个计数分配给长度,直到没有余数为止。

可选择固定生成器以获得可重复的结果,例如:

示例

>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
Parameters
  • 数据集 (数据集) – 要分割的数据集

  • 长度 (序列) – 要生成的分割的长度或分数

  • 生成器 (生成器) – 用于随机排列的生成器。

Return type

列表[子集[T]]

class torch.utils.data.Sampler(data_source=None)[源代码]

所有采样器的基类。

每个 Sampler 子类都必须提供一个 __iter__() 方法,用于迭代数据集元素的索引或索引列表(批次),以及一个 __len__() 方法,用于返回迭代器的长度。

Parameters

数据源 (数据集) – 此参数未被使用,并将在2.2.0版本中移除。 您可能仍然有自定义实现使用它。

示例

>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()

注意

The __len__() 方法并不是严格要求由 DataLoader 使用,但在任何涉及 DataLoader 长度的计算中都是预期的。

class torch.utils.data.SequentialSampler(data_source)[源代码]

按顺序采样元素,始终保持相同的顺序。

Parameters

数据源 (数据集) – 从中采样的数据集

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[源代码]

随机抽取样本元素。如果不进行替换,则从打乱的数据集中进行抽样。

如果是有放回的,那么用户可以指定 num_samples 来抽取样本。

Parameters
  • 数据源 (数据集) – 从中采样的数据集

  • replacement (bool) – 如果为True,则按需进行有放回抽样,默认值为``False``

  • num_samples (int) – 要抽取的样本数量,默认为`len(dataset)`。

  • 生成器 (生成器) – 用于采样的生成器。

class torch.utils.data.SubsetRandomSampler(indices, generator=None)[源代码]

从给定的索引列表中随机抽取元素,不放回。

Parameters
  • indices (序列) – 一个索引序列

  • 生成器 (生成器) – 用于采样的生成器。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[源代码]

[0,..,len(weights)-1] 中按给定的概率(权重)抽取样本元素。

Parameters
  • weights (序列) – 一个权重序列,不需要总和为一

  • num_samples (int) – 要抽取的样本数量

  • replacement (bool) – 如果为True,则进行有放回抽样。如果不为真,则进行无放回抽样,这意味着当为某一行抽取一个样本索引时,该索引不能再次为该行抽取。

  • 生成器 (生成器) – 用于采样的生成器。

示例

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[源代码]

包装另一个采样器以生成一小批索引。

Parameters
  • 采样器 (SamplerIterable) – 基础采样器。可以是任何可迭代对象

  • batch_size (int) – 小批量大小。

  • drop_last (bool) – 如果True,采样器将在其大小小于batch_size时丢弃最后一个批次

示例

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[源代码]

采样器,限制数据加载到数据集的一个子集。

它与 torch.nn.parallel.DistributedDataParallel 结合使用时特别有用。在这种情况下,每个 进程可以将一个 DistributedSampler 实例作为 DataLoader 采样器传递,并加载原始数据集的一个独占子集。

注意

假设数据集的大小是恒定的,并且它的任何实例总是以相同的顺序返回相同的元素。

Parameters
  • dataset – 用于采样的数据集。

  • num_replicas (int, 可选) – 参与分布式训练的进程数量。默认情况下,world_size 从当前分布式组中获取。

  • rank (int, 可选) – 当前进程在 num_replicas 中的排名。 默认情况下,rank 是从当前分布式组中获取的。

  • shuffle (bool, 可选) – 如果 True(默认),采样器将打乱索引。

  • 种子 (整数, 可选) – 如果shuffle=True,用于打乱采样器的随机种子。这个数字在分布式组中的所有进程中应该相同。默认值:0

  • drop_last (布尔值, 可选) – 如果True,则采样器将丢弃数据的尾部,使其在副本数量上均匀可分。如果False,采样器将添加额外的索引以使数据在副本上均匀可分。默认值:False

警告

在分布式模式下,在每个epoch开始时之前创建DataLoader迭代器之前调用set_epoch()方法是必要的,以确保在多个epoch中正确地进行洗牌。否则,将始终使用相同的顺序。

示例:

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)
优云智算