Source code for torch.utils.data.dataloader
r"""DataLoader 及其相关迭代器的定义,这些迭代器是 _BaseDataLoaderIter 的子类。
为了支持这两个类,在 `./_utils` 中我们定义了许多实用方法和
在多进程中运行的函数。例如,数据加载工作循环位于 `./_utils/worker.py` 中。
"""
import functools
import itertools
import logging
import os
import queue
import threading
import warnings
from typing import Any, Callable, Iterable, TypeVar, Generic, List, Optional, Union
import multiprocessing as python_multiprocessing
import torch
import torch.distributed as dist
import torch.multiprocessing as multiprocessing
import torch.utils.data.graph_settings
from torch._utils import ExceptionWrapper
from . import (
IterDataPipe,
MapDataPipe,
IterableDataset,
Sampler,
SequentialSampler,
RandomSampler,
BatchSampler,
Dataset,)
from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper
from . import _utils
__all__ = [
"DataLoader",
"get_worker_info",
"default_collate",
"default_convert",
]
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
_worker_init_fn_t = Callable[[int], None]
# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently
# no way to have that type parameter default to something when the user doesn't pass in a custom `collate_fn`.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]
# These functions used to be defined in this file. However, they were moved to `_utils/collate.py`.
# Although it is rather hard to access them from user land (one has to explicitly and directly
# `import torch.utils.data.dataloader`), there probably is user code out there relying on them.
# These aliases maintain backward compatibility in that respect.
default_collate: _collate_fn_t = _utils.collate.default_collate
default_convert = _utils.collate.default_convert
get_worker_info = _utils.worker.get_worker_info
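# Illustrative sketch (not part of the original module): ``default_collate`` stacks a
# list of per-sample elements into a single batched tensor, whereas ``default_convert``
# only converts data (e.g. NumPy arrays) to tensors without adding a batch dimension.
# The sample values below are assumptions chosen for demonstration.
#
#     >>> default_collate([torch.tensor([1, 2]), torch.tensor([3, 4])])
#     tensor([[1, 2],
#             [3, 4]])
#     >>> default_collate([0, 1, 2])
#     tensor([0, 1, 2])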
logger = logging.getLogger(__name__)
class _DatasetKind:
Map = 0
Iterable = 1
@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
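# Illustrative sketch (not part of the original module): a simplified view of what the
# two fetcher kinds (implemented in ``_utils/fetch.py``) do with a list of indices.
# The helper name below is hypothetical and only sketches the map-style case.
#
#     >>> def _simplified_map_fetch(dataset, indices, collate_fn):
#     ...     # Map-style: index into the dataset, then collate the samples into a batch.
#     ...     return collate_fn([dataset[idx] for idx in indices])
#     >>> # An iterable-style fetcher instead pulls ``len(indices)`` items from the
#     >>> # dataset's iterator; the "indices" are just placeholders produced by
#     >>> # ``_InfiniteConstantSampler`` below.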
class _InfiniteConstantSampler(Sampler):
r"""类似于 ``itertools.repeat(None, None)``。
用于 :class:`~torch.utils.data.IterableDataset` 的采样器。
"""
def __iter__(self):
while True:
yield None
def _get_distributed_settings():
if dist.is_available() and dist.is_initialized():
return dist.get_world_size(), dist.get_rank()
else:
return 1, 0
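# Illustrative note (not part of the original module): if torch.distributed is
# unavailable or ``init_process_group`` has not been called, the helper above falls
# back to a world size of 1 and rank 0, so sharding degenerates to plain per-worker
# splitting within a single process.
#
#     >>> _get_distributed_settings()  # in a non-distributed run
#     (1, 0)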
def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
global_worker_id = worker_id
info = torch.utils.data.get_worker_info()
assert info is not None
total_workers = info.num_workers
datapipe = info.dataset
assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
    # To distribute elements evenly across distributed processes, we shard the data across
    # the distributed processes first, and then shard it across the worker processes.
total_workers *= world_size
global_worker_id = global_worker_id * world_size + rank_id
    # For backward compatibility, use the default SHARDING_PRIORITIES.
torch.utils.data.graph_settings.apply_sharding(datapipe, total_workers, global_worker_id)
if worker_init_fn is not None:
worker_init_fn(worker_id)
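# Illustrative sketch (not part of the original module): the arithmetic above flattens
# (rank, worker) pairs into unique global worker ids so that every DataPipe shard is
# distinct. With 2 distributed ranks and 3 workers per rank (values assumed for
# demonstration), the 6 shards are assigned as follows:
#
#     >>> world_size, num_workers = 2, 3
#     >>> [(rank, worker, worker * world_size + rank)
#     ...  for rank in range(world_size) for worker in range(num_workers)]
#     [(0, 0, 0), (0, 1, 2), (0, 2, 4), (1, 0, 1), (1, 1, 3), (1, 2, 5)]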
def _share_dist_seed(generator, pg):
    _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)