Shortcuts

torch.utils.data.sampler 的源代码

import torch
from torch import Tensor

from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union

__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]

T_co = TypeVar('T_co', covariant=True)


[docs]class Sampler(Generic[T_co]): r"""所有采样器的基类。 每个采样器子类都必须提供一个 :meth:`__iter__` 方法,提供一种迭代数据集元素索引或索引列表(批次)的方式,以及一个 :meth:`__len__` 方法 返回返回迭代器的长度。 参数: data_source (Dataset): 此参数未使用,将在 2.2.0 中删除。 您可能仍然有自定义实现使用它。 示例: >>> # xdoctest: +SKIP >>> class AccedingSequenceLengthSampler(Sampler[int]): >>> def __init__(self, data: List[str]) -> None: >>> self.data = data >>> >>> def __len__(self) -> int: >>> return len(self.data) >>> >>> def __iter__(self) -> Iterator[int]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> yield from torch.argsort(sizes).tolist() >>> >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): >>> def __init__(self, data: List[str], batch_size: int) -> None: >>> self.data = data >>> self.batch_size = batch_size >>> >>> def __len__(self) -> int: >>> return (len(self.data) + self.batch_size - 1) // self.batch_size >>> >>> def __iter__(self) -> Iterator[List[int]]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): >>> yield batch.tolist() .. 注意:: :meth:`__len__` 方法并不是严格要求由 :class:`~torch.utils.data.DataLoader` 使用,但在任何 涉及 :class:`~torch.utils.data.DataLoader` 长度的计算中都期望使用。 """ def __init__(self, data_source: Optional[Sized] = None) -> None: if data_source is not None: import warnings warnings.warn("`data_source` 参数未使用,将在 2.2.0 中删除。" "您可能仍然有自定义实现使用它。") def __iter__(self) -> Iterator[T_co]: raise NotImplementedError
# 注意 [ Python 抽象基类中缺少默认的 `__len__` ] # # 很多时候我们有一个抽象类表示数据的集合/可迭代对象,例如 `torch.utils.data.Sampler`,其子类可以选择性地 # 实现一个 `__len__` 方法。在这种情况下,我们必须确保不提供默认实现,因为两种直接的默认 # 实现都有其问题: # # + `return NotImplemented`: # 调用 `len(subclass_instance)` 会引发: # TypeError: 'NotImplementedType' 对象不能解释为整数 # # + `raise NotImplementedError()`: # 这会阻止触发某些回退行为。例如,内置的 `list(X)` 首先尝试调用 `len(X)`,如果方法未找到或返回 `NotImplemented`, # 则执行不同的代码路径,而引发 `NotImplementedError` 会导致错误传播并使调用失败, # 而它本可以使用 `__iter__` 完成调用。 # # 因此,唯一合理的做法是 # # + **不** 提供默认的 `__len__`。 # # + 引发 `TypeError`,这是 Python 在用户调用对象上未定义的方法时使用的行为。 # (@ssnl 验证了这在至少 Python 3.7 中有效。)
[docs]class SequentialSampler(Sampler[int]): r"""按顺序采样元素,始终保持相同的顺序。 参数: data_source (Dataset): 要从中采样的数据集 """ data_source: Sized def __init__(self, data_source: Sized) -> None: self.data_source = data_source def __iter__(self) -> Iterator[int]: return iter(range(len(self.data_source))) def __len__(self) -> int: return len(self.data_source)
[docs]class RandomSampler(Sampler[int]): r"""随机采样元素。如果不进行替换,则从打乱的数据集中采样。 如果进行替换,则用户可以指定 :attr:`num_samples` 来抽取。 参数: data_source (Dataset): 要从中采样的数据集 replacement (bool): 如果为 ``True``,则进行替换采样,默认=``False`` num_samples (int): 要抽取的样本数量,默认=`len(dataset)`。 generator (Generator): 采样中使用的生成器。 """ data_source: Sized replacement: bool def __init__(self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None) -> None: self.data_source = data_source self.replacement = replacement self._num_samples = <span
优云智算