WN18数据集

class dgl.data.WN18Dataset(reverse=True, raw_dir=None, force_reload=False, verbose=True, transform=None)[source]

基础类: KnowledgeGraphDataset

WN18 链接预测数据集。

WN18数据集在翻译嵌入以建模多关系数据中被引入。 它包含了从WordNet中抓取的完整18种关系,大约涉及41,000个同义词集。在创建数据集时,默认情况下会为每条边创建一个具有反向关系类型的反向边。

WN18 数据集统计:

  • 节点数:40943

  • 关系类型的数量:18

  • 反向关系类型的数量:18

  • 标签分割:

    • 训练: 141442

    • 有效: 5000

    • 测试: 5000

Parameters:
  • reverse (bool) – 是否添加反向边。默认为True。

  • raw_dir (str) – Raw file directory to download/contains the input data directory. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: True.

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.

num_nodes

节点数量

Type:

int

num_rels

关系类型的数量

Type:

int

示例

>>> dataset = WN18Dataset()
>>> g = dataset.graph
>>> e_type = g.edata['e_type']
>>>
>>> # get data split
>>> train_mask = g.edata['train_mask']
>>> val_mask = g.edata['val_mask']
>>>
>>> train_set = th.arange(g.num_edges())[train_mask]
>>> val_set = th.arange(g.num_edges())[val_mask]
>>>
>>> # build train_g
>>> train_edges = train_set
>>> train_g = g.edge_subgraph(train_edges,
                              relabel_nodes=False)
>>> train_g.edata['e_type'] = e_type[train_edges];
>>>
>>> # build val_g
>>> val_edges = th.cat([train_edges, val_edges])
>>> val_g = g.edge_subgraph(val_edges,
                            relabel_nodes=False)
>>> val_g.edata['e_type'] = e_type[val_edges];
>>>
>>> # Train, Validation and Test
>>>
__getitem__(idx)[source]

获取图形对象

Parameters:

idx (int) – 项目索引,WN18Dataset只有一个图对象

Returns:

图中包含

  • edata['e_type']: 边关系类型

  • edata['train_edge_mask']: 正训练边掩码

  • edata['val_edge_mask']: 正验证边掩码

  • edata['test_edge_mask']: 正测试边掩码

  • edata['train_mask']: 训练边集掩码(包括反向训练边)

  • edata['val_mask']: 验证边集掩码(包括反向验证边)

  • edata['test_mask']: 测试边集掩码(包括反向测试边)

  • ndata['ntype']: 节点类型。在此数据集中全部为0

Return type:

dgl.DGLGraph

__len__()[source]

数据集中图的数量。