Yelp数据集

class dgl.data.YelpDataset(raw_dir=None, force_reload=False, verbose=False, transform=None, reorder=False)[source]

Bases: DGLBuiltinDataset

用于节点分类的Yelp数据集来自GraphSAINT: 基于图采样的归纳学习方法

该数据集的任务是基于客户评论和友谊对业务类型进行分类。

Yelp 数据集统计:

  • 节点数:716,847

  • 边数: 13,954,819

  • 类别数量:100(多类别)

  • 节点特征大小:300

Parameters:
  • raw_dir (str) – Raw file directory to download/contains the input data directory. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: False

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.

  • reorder (bool) – 是否使用reorder_graph()重新排序图。 默认值:False。

num_classes

节点类的数量

Type:

int

示例

>>> dataset = YelpDataset()
>>> dataset.num_classes
100
>>> g = dataset[0]
>>> # get node feature
>>> feat = g.ndata['feat']
>>> # get node labels
>>> labels = g.ndata['label']
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
__getitem__(idx)[source]

获取图形对象

Parameters:

idx (int) – 项目索引,FlickrDataset 只有一个图对象

Returns:

图表包含:

  • ndata['label']: 节点标签

  • ndata['feat']: 节点特征

  • ndata['train_mask']: 训练节点集的掩码

  • ndata['val_mask']: 验证节点集的掩码

  • ndata['test_mask']: 测试节点集的掩码

Return type:

dgl.DGLGraph

__len__()[source]

数据集中图的数量。