Source code for torch.utils.data.dataloader

r"""DataLoader 及其相关迭代器的定义,这些迭代器是 _BaseDataLoaderIter 的子类。

为了支持这两个类,在 `./_utils` 中我们定义了许多实用方法和
在多进程中运行的函数。例如,数据加载工作循环位于 `./_utils/worker.py` 中。
"""

import functools
import itertools
import logging
import os
import queue
import threading
import warnings

from typing import Any, Callable, Iterable, TypeVar, Generic, List, Optional, Union

import multiprocessing as python_multiprocessing
import torch
import torch.distributed as dist
import torch.multiprocessing as multiprocessing
import torch.utils.data.graph_settings

from torch._utils import ExceptionWrapper

from . import (
    IterDataPipe,
    MapDataPipe,
    IterableDataset,
    Sampler,
    SequentialSampler,
    RandomSampler,
    BatchSampler,
    Dataset,
)

from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper

from . import _utils

__all__ = [
    "DataLoader",
    "get_worker_info",
    "default_collate",
    "default_convert",
]

T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
_worker_init_fn_t = Callable[[int], None]

# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`,
# but there is currently no way to have that type parameter set to a default
# value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]


# These functions used to be defined in this file. However, they were moved to
# `_utils/collate.py`. Although it is rather hard to access them from user land
# (one has to explicitly and directly `import torch.utils.data.dataloader`),
# there probably is user code out there relying on them. These aliases maintain
# backward compatibility in this aspect.
default_collate: _collate_fn_t = _utils.collate.default_collate
default_convert = _utils.collate.default_convert
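
# Illustrative sketch (not part of the original module): ``default_collate``
# stacks the tensors in a batch along a new leading dimension and recursively
# converts plain numbers to tensors, so a list of (feature, label) samples
# collates into one (features, labels) pair of batched tensors:
#
#     >>> batch = [(torch.tensor([1.0, 2.0]), 0), (torch.tensor([3.0, 4.0]), 1)]
#     >>> features, labels = default_collate(batch)
#     >>> features.shape
#     torch.Size([2, 2])
#     >>> labels
#     tensor([0, 1])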

get_worker_info = _utils.worker.get_worker_info
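
# Illustrative sketch (not part of the original module): ``get_worker_info``
# returns ``None`` in the main process and a ``WorkerInfo`` object (with
# ``id``, ``num_workers``, ``seed``, and ``dataset`` attributes) inside a
# worker process. ``my_worker_init`` below is a hypothetical user function,
# passed as ``DataLoader(..., worker_init_fn=my_worker_init)``:
#
#     >>> def my_worker_init(worker_id):
#     ...     info = torch.utils.data.get_worker_info()
#     ...     assert info is not None and info.id == worker_id
#     ...     # e.g. seed per-worker RNGs or split an IterableDataset here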

logger = logging.getLogger(__name__)


class _DatasetKind:
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)


class _InfiniteConstantSampler(Sampler):
    r"""类似于 ``itertools.repeat(None, None)``。

    用于 :class:`~torch.utils.data.IterableDataset` 的采样器。
    """

    def __iter__(self):
        while True:
            yield None
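
# Illustrative sketch (not part of the original module): this sampler yields
# ``None`` forever; the actual elements come from the IterableDataset itself,
# so the DataLoader only needs a placeholder index stream:
#
#     >>> import itertools
#     >>> list(itertools.islice(_InfiniteConstantSampler(), 3))
#     [None, None, None]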


def _get_distributed_settings():
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()
    else:
        return 1, 0


def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
    global_worker_id = worker_id
    info = torch.utils.data.get_worker_info()
    assert info is not None
    total_workers = info.num_workers
    datapipe = info.dataset
    assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
    # To distribute elements across distributed processes evenly, we should shard
    # data on distributed processes first, then shard on worker processes.
    total_workers *= world_size
    global_worker_id = global_worker_id * world_size + rank_id
    # For backward compatibility, use the default SHARDING_PRIORITIES.
    torch.utils.data.graph_settings.apply_sharding(datapipe, total_workers, global_worker_id)
    if worker_init_fn is not None:
        worker_init_fn(worker_id)
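
# Worked example (illustrative): with ``world_size=2`` distributed ranks and
# ``num_workers=3`` DataLoader workers per rank, the datapipe is split into
# ``3 * 2 = 6`` shards. Worker ``worker_id=1`` on ``rank_id=0`` is assigned
# global shard ``1 * 2 + 0 = 2``, while the same worker on ``rank_id=1`` gets
# ``1 * 2 + 1 = 3``, so every (rank, worker) pair reads a disjoint shard.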


def _share_dist_seed(generator, pg):
    _shared_seed = torch.empty<