Shortcuts

torch.utils.data.distributed 的源代码

import math
from typing import TypeVar, Optional, Iterator

import torch
from . import Sampler, Dataset
import torch.distributed as dist

__all__ = ["DistributedSampler", ]

T_co = TypeVar('T_co', covariant=True)


[docs]class DistributedSampler(Sampler[T_co]): r"""采样器,限制数据加载到数据集的一个子集。 在与 :class:`torch.nn.parallel.DistributedDataParallel` 结合使用时特别有用。在这种情况下,每个 进程可以传递一个 :class:`~torch.utils.data.DistributedSampler` 实例作为 :class:`~torch.utils.data.DataLoader` 的采样器,并加载一个独占于它的原始数据集的子集。 .. 注意:: 假设数据集的大小是恒定的,并且它的任何实例总是 以相同的顺序返回相同的元素。 参数: dataset: 用于采样的数据集。 num_replicas (int, 可选): 参与分布式训练的进程数。 默认情况下,从当前分布式组中检索 :attr:`world_size`。 rank (int, 可选): 当前进程在 :attr:`num_replicas` 中的排名。 默认情况下,从当前分布式组中检索 :attr:`rank`。 shuffle (bool, 可选): 如果为 ``True`` (默认),采样器将打乱索引。 seed (int, 可选): 如果 :attr:`shuffle=True`,用于打乱采样器的随机种子。 这个数字在分布式组中的所有进程中应该相同。默认值: ``0``。 drop_last (bool, 可选): 如果为 ``True``,采样器将丢弃数据的尾部,使其在副本数量上均匀可分。 如果为 ``False``,采样器将添加额外的索引,使数据在副本上均匀可分。默认值: ``False``。 .. 警告:: 在分布式模式下,在每个 epoch 开始时调用 :meth:`set_epoch` 方法 **在创建 :class:`DataLoader` 迭代器之前** 是必要的,以确保在多个 epoch 中正确打乱。否则, 将始终使用相同的顺序。 示例:: >>> # xdoctest: +SKIP >>> sampler = DistributedSampler(dataset) if is_distributed else None >>> loader = DataLoader(dataset, shuffle=(sampler is None), ... sampler=sampler) >>> for epoch in range(start_epoch, n_epochs): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader) """ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, drop_last: bool = False) -> None: if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") num_replicas = dist.get_world_size() if rank is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() if rank >= num_replicas or rank < 0: raise ValueError( f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.epoch = 0 self.drop_last = drop_last # 如果数据集长度在副本数量上均匀可分,则无需丢弃任何数据,因为数据集将被均匀分割。 if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] # 分割到最近的可均匀分割的长度。 # 这是为了确保在使用此采样器时每个 rank 接收到相同数量的数据。 self.num_samples = math.ceil( (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] ) else: self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed def __iter__(self) -> Iterator[T_co]: if self.shuffle: # 基于 epoch 和 seed 确定性地打乱 g = torch.Generator() g.manual_seed(self.seed + self.epoch) indices = torch.randperm(len(self.dataset), generator<
优云智算