CoraGraphDataset

class dgl.data.CoraGraphDataset(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None, reorder=False)[source]

基础类:CitationGraphDataset

Cora引用网络数据集。

节点代表论文,边代表引用关系。每个节点都有一个预定义的1433维特征。该数据集是为节点分类任务设计的。任务是预测某篇论文的类别。

统计:

  • 节点数:2708

  • 边数: 10556

  • 班级数量:7

  • 标签分割:

    • 训练:140

    • 有效: 500

    • 测试: 1000

Parameters:
  • raw_dir (str) – 用于下载/包含输入数据目录的原始文件目录。 默认值:~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: True.

  • reverse_edge (bool) – 是否在图中添加反向边。默认值:True。

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.

  • reorder (bool) – 是否使用reorder_graph()重新排序图。默认值:False。

num_classes

标签类别数量

Type:

int

注释

节点特征是行归一化的。

示例

>>> dataset = CoraGraphDataset()
>>> g = dataset[0]
>>> num_class = dataset.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
__getitem__(idx)[source]

获取图形对象

Parameters:

idx (int) – 项目索引,CoraGraphDataset只有一个图对象

Returns:

图结构、节点特征和标签。

  • ndata['train_mask']: 训练节点集的掩码

  • ndata['val_mask']: 验证节点集的掩码

  • ndata['test_mask']: 测试节点集的掩码

  • ndata['feat']: 节点特征

  • ndata['label']: 真实标签

Return type:

dgl.DGLGraph

__len__()[source]

数据集中图的数量。