MNIST Superpixel Dataset
- class dgl.data.MNISTSuperPixelDataset(raw_dir=None, split='train', use_feature=False, force_reload=False, verbose=False, transform=None)
Bases: SuperPixelDataset
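MNIST superpixel dataset for the graph classification task.
DGL dataset of MNIST and CIFAR10 from benchmarking-gnns, containing graphs converted from the original MNIST and CIFAR10 images.
Reference: http://arxiv.org/abs/2003.00982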
Statistics:
Train examples: 60,000
Test examples: 10,000
Size of dataset images: 28 (i.e., 28 × 28 pixels)
- Parameters:
raw_dir (str) – Directory to store all the downloaded raw datasets. Default: “~/.dgl/”.
split (str) – Should be chosen from [“train”, “test”]. Default: “train”.
use_feature (bool) –
True: adjacency matrix defined from superpixel locations and features.
False: adjacency matrix defined from superpixel locations only.
Default: False.
force_reload (bool) – Whether to reload the dataset. Default: False.
verbose (bool) – Whether to print out progress information. Default: False.
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
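The use_feature option only states that the adjacency comes from superpixel locations, optionally combined with features. The sketch below is a simplified, dependency-free illustration loosely following the construction described in the benchmarking-gnns reference: a Gaussian kernel over pairwise distances, each scaled by a per-node bandwidth, followed by keeping the k most similar neighbours of every node. It is not DGL's internal code. The function name knn_superpixel_edges, the bandwidth rule, and k = 8 are assumptions. With k = 8 every node contributes 8 edges, which is consistent with the 568 edges on the 71-node graph in the example (71 × 8 = 568).

```python
import math


def knn_superpixel_edges(coords, feats=None, k=8):
    """Hypothetical k-nearest-neighbour superpixel graph builder (simplified sketch).

    Assumption: edge weight = exp(-(d_coord/sigma)^2 [- (d_feat/sigma)^2]),
    symmetrised, then the top-k neighbours per node are kept.
    coords: list of (x, y) superpixel centres.
    feats:  optional list of per-superpixel feature vectors; when given,
            feature distance also enters the weight (the analogue of
            use_feature=True in this sketch).
    Returns a list of (src, dst, weight) tuples, k per source node.
    """
    n = len(coords)
    if n <= k:
        raise ValueError("need more than k superpixels")

    def dist(a, b):
        return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

    def scaled_sq_dists(points):
        # Squared pairwise distances; row i is divided by sigma_i, the mean
        # distance from node i to its k nearest other nodes.
        d = [[dist(points[i], points[j]) for j in range(n)] for i in range(n)]
        scaled = []
        for i in range(n):
            nearest = sorted(d[i][j] for j in range(n) if j != i)[:k]
            sigma = sum(nearest) / k + 1e-8  # epsilon avoids division by zero
            scaled.append([(d[i][j] / sigma) ** 2 for j in range(n)])
        return scaled

    expo = scaled_sq_dists(coords)
    if feats is not None:
        f = scaled_sq_dists(feats)
        expo = [[expo[i][j] + f[i][j] for j in range(n)] for i in range(n)]

    # Gaussian kernel, symmetrised, with self-similarity zeroed out.
    raw = [[math.exp(-expo[i][j]) for j in range(n)] for i in range(n)]
    w = [[0.0 if i == j else 0.5 * (raw[i][j] + raw[j][i]) for j in range(n)]
         for i in range(n)]

    edges = []
    for i in range(n):
        # Keep the k neighbours with the largest weight.
        nbrs = sorted((j for j in range(n) if j != i),
                      key=lambda j: w[i][j], reverse=True)[:k]
        edges.extend((i, j, w[i][j]) for j in nbrs)
    return edges


# 12 hypothetical superpixel centres on a 4 x 3 grid, one intensity each.
coords = [(x, y) for x in range(4) for y in range(3)]
intensity = [[0.1 * i] for i in range(12)]

loc_only = knn_superpixel_edges(coords, k=8)
with_feat = knn_superpixel_edges(coords, feats=intensity, k=8)
print(len(loc_only), len(with_feat))  # 96 96, i.e. 12 nodes x 8 edges
```

Only the weights differ between the two calls; the per-node edge count is fixed by k in either mode.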
Example
>>> from dgl.data import MNISTSuperPixelDataset
>>> # MNIST dataset
>>> train_dataset = MNISTSuperPixelDataset(split="train")
>>> len(train_dataset)
60000
>>> graph, label = train_dataset[0]
>>> graph
Graph(num_nodes=71, num_edges=568,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
>>> # a 1-D tensor can be used as the index when transform is None
>>> # see __getitem__ for details
>>> import torch
>>> idx = torch.tensor([0, 1, 2])
>>> train_dataset_subset = train_dataset[idx]
>>> train_dataset_subset[0]
Graph(num_nodes=71, num_edges=568,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
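Each sample is a (graph, label) pair, and superpixel graphs need not all have the same number of nodes (71 in the example above). For graph classification, node features are commonly pooled into one fixed-size vector before a classifier; DGL provides dgl.mean_nodes(graph, "feat") for this. The sketch below shows the same averaging for a single graph without any dependencies, with node features given as hypothetical Python lists rather than tensors. The name mean_readout is illustrative only.

```python
def mean_readout(node_feats):
    """Hypothetical helper: average per-node feature vectors into one vector."""
    if not node_feats:
        raise ValueError("graph has no nodes")
    n = len(node_feats)
    dim = len(node_feats[0])
    return [sum(vec[d] for vec in node_feats) / n for d in range(dim)]


# Two hypothetical graphs with 3-dim node features but different node counts.
small = [[0.25, 0.5, 1.0], [0.75, 0.5, 0.0]]
large = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]
print(mean_readout(small))  # [0.5, 0.5, 0.5]
print(mean_readout(large))  # [0.25, 0.25, 0.25]
```

Both graphs map to a 3-dim vector, so a classifier sees the same input size regardless of node count.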
- __getitem__(idx)
Get the idx-th sample.
- Parameters:
idx (int or tensor) – The sample index. A 1-D tensor may be passed as idx when transform is None.
- Returns:
(dgl.DGLGraph, Tensor) – Graph with node feature stored in the feat field, and its label; or
dgl.data.utils.Subset – Subset of the dataset at the specified indices.
- __len__()
Number of examples in the dataset.
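The __getitem__ and __len__ contract above can be illustrated with a small, hypothetical toy dataset. Everything here is an assumption for illustration: ToyGraphDataset and ToySubset are not DGL classes, strings stand in for DGLGraph objects, and a Python list stands in for a 1-D tensor index. The toy raises an error for a list index when a transform is set; that is a choice made in this sketch to mirror the documented condition, not a description of what DGL does in that case.

```python
class ToySubset:
    """Hypothetical stand-in for a subset view: indexes the parent lazily."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]

    def __len__(self):
        return len(self.indices)


class ToyGraphDataset:
    """Hypothetical dataset mimicking the documented indexing behaviour."""

    def __init__(self, graphs, labels, transform=None):
        self.graphs = graphs
        self.labels = labels
        self.transform = transform

    def __getitem__(self, idx):
        if isinstance(idx, (list, tuple)):
            # Sequence index (standing in for a 1-D tensor) -> subset view.
            if self.transform is not None:
                raise TypeError("sequence index only supported when transform is None")
            return ToySubset(self, idx)
        g = self.graphs[idx]
        if self.transform is not None:
            # Applied on every access; the stored graph is left unchanged.
            g = self.transform(g)
        return g, self.labels[idx]

    def __len__(self):
        return len(self.graphs)


ds = ToyGraphDataset(["g0", "g1", "g2", "g3"], [5, 0, 4, 1])
sub = ds[[0, 2]]
print(len(ds), len(sub), sub[1])  # 4 2 ('g2', 4)
```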