torch.utils.data.distributed 的源代码
import math
from typing import TypeVar, Optional, Iterator
import torch
from . import Sampler, Dataset
import torch.distributed as dist
__all__ = ["DistributedSampler", ]
T_co = TypeVar('T_co', covariant=True)
[docs]class DistributedSampler(Sampler[T_co]):
r"""采样器,限制数据加载到数据集的一个子集。
在与 :class:`torch.nn.parallel.DistributedDataParallel` 结合使用时特别有用。在这种情况下,每个
进程可以传递一个 :class:`~torch.utils.data.DistributedSampler` 实例作为
:class:`~torch.utils.data.DataLoader` 的采样器,并加载一个独占于它的原始数据集的子集。
.. 注意::
假设数据集的大小是恒定的,并且它的任何实例总是
以相同的顺序返回相同的元素。
参数:
dataset: 用于采样的数据集。
num_replicas (int, 可选): 参与分布式训练的进程数。
默认情况下,从当前分布式组中检索 :attr:`world_size`。
rank (int, 可选): 当前进程在 :attr:`num_replicas` 中的排名。
默认情况下,从当前分布式组中检索 :attr:`rank`。
shuffle (bool, 可选): 如果为 ``True`` (默认),采样器将打乱索引。
seed (int, 可选): 如果 :attr:`shuffle=True`,用于打乱采样器的随机种子。
这个数字在分布式组中的所有进程中应该相同。默认值: ``0``。
drop_last (bool, 可选): 如果为 ``True``,采样器将丢弃数据的尾部,使其在副本数量上均匀可分。
如果为 ``False``,采样器将添加额外的索引,使数据在副本上均匀可分。默认值: ``False``。
.. 警告::
在分布式模式下,在每个 epoch 开始时调用 :meth:`set_epoch` 方法
**在创建 :class:`DataLoader` 迭代器之前** 是必要的,以确保在多个 epoch 中正确打乱。否则,
将始终使用相同的顺序。
示例::
>>> # xdoctest: +SKIP
>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
... sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
... if is_distributed:
... sampler.set_epoch(epoch)
... train(loader)
"""
def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
rank: Optional[int] = None, shuffle: bool = True,
seed: int = 0, drop_last: bool = False) -> None:
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if rank >= num_replicas or rank < 0:
raise ValueError(
f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
# 如果数据集长度在副本数量上均匀可分,则无需丢弃任何数据,因为数据集将被均匀分割。
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
# 分割到最近的可均匀分割的长度。
# 这是为了确保在使用此采样器时每个 rank 接收到相同数量的数据。
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
def __iter__(self) -> Iterator[T_co]:
if self.shuffle:
# 基于 epoch 和 seed 确定性地打乱
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator<