AIFBDataset

class dgl.data.AIFBDataset(print_every=10000, insert_reverse=True, raw_dir=None, force_reload=False, verbose=True, transform=None)[source]

基础类:RDFGraphDataset

用于节点分类任务的AIFB数据集

AIFB数据集是一个语义网(RDF)数据集,用作数据挖掘的基准。它记录了卡尔斯鲁厄大学AIFB的组织结构。

AIFB数据集统计:

  • 节点数:7262

  • 边数:48810(包括反向边)

  • 目标类别:人员

  • 班级数量:4

  • 标签分割:

    • 训练:140

    • 测试: 36

Parameters:
  • print_every (int) – 每X个元组预处理日志。默认值:10000。

  • insert_reverse (bool) – 如果为真,将反向边和反向关系添加到最终图中。默认值:True。

  • raw_dir (str) – Raw file directory to download/contains the input data directory. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: True.

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.

num_classes

预测的类别数量

Type:

int

predict_category

具有预测标签的实体类别(节点类型)

Type:

str

示例

>>> dataset = dgl.data.rdf.AIFBDataset()
>>> graph = dataset[0]
>>> category = dataset.predict_category
>>> num_classes = dataset.num_classes
>>>
>>> train_mask = g.nodes[category].data['train_mask']
>>> test_mask = g.nodes[category].data['test_mask']
>>> label = g.nodes[category].data['label']
__getitem__(idx)[source]

获取图形对象

Parameters:

idx (int) – 项目索引,AIFBDataset 只有一个图对象

Returns:

图表包含:

  • ndata['train_mask']: 训练节点集的掩码

  • ndata['test_mask']: 测试节点集的掩码

  • ndata['label']: 节点标签

Return type:

dgl.DGLGraph

__len__()[source]

数据集中图的数量。

Return type:

int