torch.utils.data.sampler 的源代码
import torch
from torch import Tensor
from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union
__all__ = [
"BatchSampler",
"RandomSampler",
"Sampler",
"SequentialSampler",
"SubsetRandomSampler",
"WeightedRandomSampler",
]
T_co = TypeVar('T_co', covariant=True)
[docs]class Sampler(Generic[T_co]):
r"""所有采样器的基类。
每个采样器子类都必须提供一个 :meth:`__iter__` 方法,提供一种迭代数据集元素索引或索引列表(批次)的方式,以及一个 :meth:`__len__` 方法
返回返回迭代器的长度。
参数:
data_source (Dataset): 此参数未使用,将在 2.2.0 中删除。
您可能仍然有自定义实现使用它。
示例:
>>> # xdoctest: +SKIP
>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>> def __init__(self, data: List[str]) -> None:
>>> self.data = data
>>>
>>> def __len__(self) -> int:
>>> return len(self.data)
>>>
>>> def __iter__(self) -> Iterator[int]:
>>> sizes = torch.tensor([len(x) for x in self.data])
>>> yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>> def __init__(self, data: List[str], batch_size: int) -> None:
>>> self.data = data
>>> self.batch_size = batch_size
>>>
>>> def __len__(self) -> int:
>>> return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>> def __iter__(self) -> Iterator[List[int]]:
>>> sizes = torch.tensor([len(x) for x in self.data])
>>> for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>> yield batch.tolist()
.. 注意:: :meth:`__len__` 方法并不是严格要求由
:class:`~torch.utils.data.DataLoader` 使用,但在任何
涉及 :class:`~torch.utils.data.DataLoader` 长度的计算中都期望使用。
"""
def __init__(self, data_source: Optional[Sized] = None) -> None:
if data_source is not None:
import warnings
warnings.warn("`data_source` 参数未使用,将在 2.2.0 中删除。"
"您可能仍然有自定义实现使用它。")
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
# 注意 [ Python 抽象基类中缺少默认的 `__len__` ]
#
# 很多时候我们有一个抽象类表示数据的集合/可迭代对象,例如 `torch.utils.data.Sampler`,其子类可以选择性地
# 实现一个 `__len__` 方法。在这种情况下,我们必须确保不提供默认实现,因为两种直接的默认
# 实现都有其问题:
#
# + `return NotImplemented`:
# 调用 `len(subclass_instance)` 会引发:
# TypeError: 'NotImplementedType' 对象不能解释为整数
#
# + `raise NotImplementedError()`:
# 这会阻止触发某些回退行为。例如,内置的 `list(X)` 首先尝试调用 `len(X)`,如果方法未找到或返回 `NotImplemented`,
# 则执行不同的代码路径,而引发 `NotImplementedError` 会导致错误传播并使调用失败,
# 而它本可以使用 `__iter__` 完成调用。
#
# 因此,唯一合理的做法是
#
# + **不** 提供默认的 `__len__`。
#
# + 引发 `TypeError`,这是 Python 在用户调用对象上未定义的方法时使用的行为。
# (@ssnl 验证了这在至少 Python 3.7 中有效。)
[docs]class SequentialSampler(Sampler[int]):
r"""按顺序采样元素,始终保持相同的顺序。
参数:
data_source (Dataset): 要从中采样的数据集
"""
data_source: Sized
def __init__(self, data_source: Sized) -> None:
self.data_source = data_source
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
[docs]class RandomSampler(Sampler[int]):
r"""随机采样元素。如果不进行替换,则从打乱的数据集中采样。
如果进行替换,则用户可以指定 :attr:`num_samples` 来抽取。
参数:
data_source (Dataset): 要从中采样的数据集
replacement (bool): 如果为 ``True``,则进行替换采样,默认=``False``
num_samples (int): 要抽取的样本数量,默认=`len(dataset)`。
generator (Generator): 采样中使用的生成器。
"""
data_source: Sized
replacement: bool
def __init__(self, data_source: Sized, replacement: bool = False,
num_samples: Optional[int] = None, generator=None) -> None:
self.data_source = data_source
self.replacement = replacement
self._num_samples = <span