Flickr dataset
- class dgl.data.FlickrDataset(raw_dir=None, force_reload=False, verbose=False, transform=None, reorder=False)[source]
Bases: DGLBuiltinDataset
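Flickr dataset for node classification, from GraphSAINT: Graph Sampling Based Inductive Learning Method.
The task is to classify online images into types based on their descriptions and common properties.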
Flickr dataset statistics:
Nodes: 89,250
Edges: 899,756
Number of classes: 7
Node feature size: 500
- Parameters:
raw_dir (str) – Directory to download the raw input data to, or that already contains it. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: False
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
reorder (bool) – Whether to reorder the graph using reorder_graph(). Default: False
Examples
>>> from dgl.data import FlickrDataset
>>> dataset = FlickrDataset()
>>> dataset.num_classes
7
>>> g = dataset[0]
>>> # get node feature
>>> feat = g.ndata['feat']
>>> # get node labels
>>> labels = g.ndata['label']
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
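The split masks are typically used to restrict loss computation and evaluation to one node set. A minimal sketch, assuming the PyTorch backend; `masked_accuracy` is a hypothetical helper (not part of DGL), and the tensors are toy stand-ins for a model's output, `g.ndata['label']` and `g.ndata['val_mask']`.

```python
import torch


def masked_accuracy(logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> float:
    """Accuracy of argmax predictions, counted only on nodes where ``mask`` is true."""
    mask = mask.bool()  # accept uint8/int masks as well as bool masks
    preds = logits[mask].argmax(dim=1)
    return (preds == labels[mask]).float().mean().item()


# Toy stand-ins: 4 nodes, 7 classes (Flickr also has 7 classes).
logits = torch.zeros(4, 7)
logits[0, 2] = 1.0  # node 0 predicted as class 2
logits[1, 5] = 1.0  # node 1 predicted as class 5
logits[2, 3] = 1.0  # node 2 predicted as class 3 (excluded by the mask)
logits[3, 0] = 1.0  # node 3 predicted as class 0
labels = torch.tensor([2, 1, 3, 0])
val_mask = torch.tensor([True, True, False, True])
acc = masked_accuracy(logits, labels, val_mask)
print(acc)  # 2 of the 3 masked nodes are correct
```

On the real graph, `logits` would have shape `(89250, 7)`, and the same helper can be called with `train_mask`, `val_mask` or `test_mask`.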
- __getitem__(idx)[source]
Get the graph object.
- Parameters:
idx (int) – Item index. FlickrDataset has only one graph object.
- Returns:
The graph contains:
ndata['label']: node label
ndata['feat']: node feature
ndata['train_mask']: mask for training node set
ndata['val_mask']: mask for validation node set
ndata['test_mask']: mask for test node set
- Return type: