Shortcuts

torch.utils.data.dataset 的源代码

```html
import bisect
import itertools
import math
import warnings
from typing import (
    cast,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

# torch/__init__.pyi 中没有 'default_generator'
from torch import default_generator, randperm

from ... import Generator, Tensor

__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "StackDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]

T_co = TypeVar("T_co", covariant=True)
T = TypeVar("T")
T_dict = Dict[str, T_co]
T_tuple = Tuple[T_co, ...]
T_stack = TypeVar("T_stack", T_tuple, T_dict)


[docs]class Dataset(Generic[T_co]): r"""表示一个 :class:`Dataset` 的抽象类。 所有表示从键到数据样本的映射的数据集都应该继承它。所有子类都应该重写 :meth:`__getitem__`,支持获取给定键的数据样本。子类也可以选择性地重写 :meth:`__len__`,许多 :class:`~torch.utils.data.Sampler` 实现和 :class:`~torch.utils.data.DataLoader` 的默认选项都期望返回数据集的大小。子类也可以选择性地实现 :meth:`__getitems__`,以加速批量样本加载。此方法接受批量样本的索引列表并返回样本列表。 .. 注意:: :class:`~torch.utils.data.DataLoader` 默认构造一个生成整数索引的索引采样器。要使其与具有非整数索引/键的映射样式数据集一起工作,必须提供自定义采样器。 """ def __getitem__(self, index) -> T_co: raise NotImplementedError("Dataset 的子类应该实现 __getitem__。") # def __getitems__(self, indices: List) -> List[T_co]: # 未实现以防止在 torch.utils.data._utils.fetch._MapDatasetFetcher 中出现误报 def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]": return ConcatDataset([self, other])
# 没有 `def __len__(self)` 的默认实现? # 参见 NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] # 在 pytorch/torch/utils/data/sampler.py 中
[docs]class IterableDataset(Dataset[T_co], Iterable[T_co]): r"""一个可迭代的数据集。 所有表示数据样本迭代的数据集都应该继承它。当数据来自流时,这种形式的数据集特别有用。 所有子类都应该重写 :meth:`__iter__`,它将返回此数据集中的样本迭代器。 当子类与 :class:`~torch.utils.data.DataLoader` 一起使用时,数据集中的每个项目将从 :class:`~torch.utils.data.DataLoader` 迭代器中生成。当 :attr:`num_workers > 0` 时,每个工作进程将有一个不同的数据集对象副本,因此通常希望独立配置每个副本以避免从工作进程返回重复数据。:func:`~torch.utils.data.get_worker_info` 在工作进程中调用时,返回有关工作进程的信息。它可以在数据集的 :meth:`__iter__` 方法或 :class:`~torch.utils.data.DataLoader` 的 :attr:`worker_init_fn` 选项中使用,以修改每个副本的行为。 示例 1:在 :meth:`__iter__` 中跨所有工作进程拆分工作负载:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) >>> # xdoctest: +SKIP("Fails on MacOS12") >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "此示例代码仅适用于 end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # 单进程数据加载,返回完整迭代器 ... iter_start = self.start ... iter_end = self.end ... else: # 在工作进程中 ... # 拆分工作负载 ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... >>> # 应该给出与 range(3, 7) 相同的数据集,即 [3, 4, 5, 6]。 >>> ds = MyIterableDataset(start=3, end=7) >>> # 单进程加载 >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [tensor([3]), tensor([4]), tensor([5]), tensor([6])] >>> # xdoctest: +REQUIRES(POSIX) >>> # 多进程加载,使用两个工作进程 >>> # 工作进程 0 获取 [3, 4]。 工作进程 1 获取 [5, 6]。 >>> # xdoctest: +IGNORE_WANT("non deterministic") >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] >>> # 使用更多工作进程 >>> # xdoctest: +IGNORE_WANT("non deterministic") >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] 示例 2:使用 :attr:`worker_init_fn` 跨所有工作进程拆分工作负载:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "此示例代码仅适用于 end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # 应该给出与 range(3, 7) 相同的数据集,即 [3, 4, 5, 6]。 >>> ds = MyIterableDataset(start=3, end=7) >>> # 单进程加载 >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # 直接进行多进程加载会返回重复数据 >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # 定义一个 `worker
优云智算