CIFAR10SuperPixelDataset
- class dgl.data.CIFAR10SuperPixelDataset(raw_dir=None, split='train', use_feature=False, force_reload=False, verbose=False, transform=None)[source]
基础类:
SuperPixelDataset
CIFAR10超像素数据集用于图分类任务。
基准GNN中的CIFAR10的DGL数据集,包含从原始CIFAR10图像转换而来的图。
参考 http://arxiv.org/abs/2003.00982
统计:
训练样本:50,000
测试示例:10,000
数据集图像的大小:32
- Parameters:
raw_dir (str) – 用于存储所有下载的原始数据集的目录。 默认值:“~/.dgl/”。
split (str) – 应从[“train”, “test”]中选择 默认值: “train”.
use_feature (bool) –
True: 从超像素位置 + 特征定义的邻接矩阵
False: 仅从超像素位置定义的邻接矩阵
默认值: False.
force_reload (bool) – 是否重新加载数据集。 默认值:False。
verbose (bool) – 是否打印进度信息。 默认值:False。
transform (callable, optional) – A transform that takes in a
DGLGraph
object and returns a transformed version. TheDGLGraph
object will be transformed before every access.
示例
>>> from dgl.data import CIFAR10SuperPixelDataset
>>> # CIFAR10 dataset >>> train_dataset = CIFAR10SuperPixelDataset(split="train") >>> len(train_dataset) 50000 >>> graph, label = train_dataset[0] >>> graph Graph(num_nodes=123, num_edges=984, ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}),
>>> # support tensor to be index when transform is None >>> # see details in __getitem__ function >>> import torch >>> idx = torch.tensor([0, 1, 2]) >>> train_dataset_subset = train_dataset[idx] >>> train_dataset_subset[0] Graph(num_nodes=123, num_edges=984, ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}),
- __getitem__(idx)
获取第 idx 个样本。
- Parameters:
idx (int 或 tensor) – 样本索引。 当 transform 为 None 时,允许使用 1-D 张量作为 idx。
- Returns:
(
dgl.DGLGraph
, Tensor) – 存储在feat
字段中的节点特征及其标签的图。或
dgl.data.utils.Subset
– 指定索引处的数据集的子集
- __len__()
数据集中的示例数量。