Source code for torch_frame.data.loader

from __future__ import annotations

import torch

from torch_frame.data import Dataset, TensorFrame
from torch_frame.typing import IndexSelectType


class DataLoader(torch.utils.data.DataLoader):
    r"""A data loader which creates mini-batches from a
    :class:`torch_frame.Dataset` or :class:`torch_frame.TensorFrame` object.

    .. code-block:: python

        import torch_frame

        dataset = ...

        loader = torch_frame.data.DataLoader(
            dataset,
            batch_size=512,
            shuffle=True,
        )

    The loader samples integer row positions ``0 .. len(dataset) - 1`` and
    builds each batch by indexing the stored :class:`TensorFrame` with the
    sampled positions (see :meth:`collate_fn`). When a :class:`Dataset` is
    given, its ``materialize()`` method is called while the loader is being
    constructed.

    Args:
        dataset (Dataset or TensorFrame): The dataset or tensor frame from
            which to load the data.
        *args (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
        **kwargs (optional): Additional keyword arguments of
            :class:`torch.utils.data.DataLoader`. A ``collate_fn`` entry,
            if present, is discarded; batches are always assembled by
            :meth:`collate_fn`.
    """
    def __init__(
        self,
        dataset: Dataset | TensorFrame,
        *args,
        **kwargs,
    ):
        # Batches are always assembled by ``self.collate_fn``. Drop any
        # caller-supplied collate function so the keyword is not passed
        # twice to the parent constructor below.
        kwargs.pop('collate_fn', None)

        if not isinstance(dataset, Dataset):
            # Already a TensorFrame: use it directly.
            frame = dataset
        else:
            # Take the TensorFrame exposed by the materialized dataset.
            frame = dataset.materialize().tensor_frame
        self.tensor_frame: TensorFrame = frame

        # The parent DataLoader only iterates over integer row positions;
        # the row data itself is fetched from ``self.tensor_frame`` inside
        # ``collate_fn``.
        row_positions = range(len(dataset))
        super().__init__(
            row_positions,
            *args,
            collate_fn=self.collate_fn,
            **kwargs,
        )

    def collate_fn(self, index: IndexSelectType) -> TensorFrame:
        """Return the rows of the stored tensor frame selected by ``index``.

        Args:
            index: Row positions for one batch, as handed over by the
                parent :class:`torch.utils.data.DataLoader`.

        Returns:
            TensorFrame: ``self.tensor_frame`` indexed by ``index``.
        """
        selected = self.tensor_frame[index]
        return selected