WN18Dataset
- class dgl.data.WN18Dataset(reverse=True, raw_dir=None, force_reload=False, verbose=True, transform=None)
Bases:
KnowledgeGraphDataset
WN18 link prediction dataset.
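The WN18 dataset was introduced in Translating Embeddings for Modeling Multi-relational Data. It includes the full 18 relations scraped from WordNet, covering roughly 41,000 synsets. When the dataset is created, a reverse edge with a reversed relation type is added for each edge by default.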
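To make the reverse-edge behaviour concrete, here is a minimal, dependency-free sketch. It is not DGL's implementation: the helper `add_reverse_edges` is hypothetical, and the convention that the reverse of relation `r` gets the id `r + num_rels` is an assumption made for illustration only.

```python
def add_reverse_edges(triples, num_rels):
    """Return the triples plus one reversed copy of each.

    Assumption (not confirmed by the text above): the reverse of
    relation r is given the id r + num_rels.
    """
    out = []
    for head, rel, tail in triples:
        out.append((head, rel, tail))             # original edge
        out.append((tail, rel + num_rels, head))  # reversed edge, shifted relation id
    return out


# Three made-up triples over 2 relation types (not real WN18 data).
triples = [(0, 0, 1), (1, 1, 2), (2, 0, 0)]
with_reverse = add_reverse_edges(triples, num_rels=2)
print(len(with_reverse))   # 6
print(with_reverse[:2])    # [(0, 0, 1), (1, 2, 0)]
```

Under this assumed convention, 18 relation types would yield 36 relation ids in total once reverse edges are added.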
WN18 dataset statistics:
Nodes: 40943
Number of relation types: 18
Number of reversed relation types: 18
Label split:
Train: 141442
Valid: 5000
Test: 5000
- Parameters:
reverse (bool) – Whether to add reverse edges. Default: True.
raw_dir (str) – Directory into which the raw input data is downloaded, or which already contains it. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: True.
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
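The phrase "transformed before every access" can be illustrated with a small stand-in class. This is a sketch only: `ToyDataset` and `drop_self_loops` are hypothetical names, and a plain list of `(src, dst)` pairs stands in for a DGLGraph so the example runs without DGL.

```python
class ToyDataset:
    """Hypothetical stand-in for a single-graph dataset; not DGL code."""

    def __init__(self, graph, transform=None):
        self._graph = graph
        self._transform = transform

    def __getitem__(self, idx):
        if idx != 0:
            raise IndexError("this dataset holds a single graph")
        # The transform runs on every access; the stored graph is not replaced.
        if self._transform is None:
            return self._graph
        return self._transform(self._graph)


def drop_self_loops(edges):
    """Example transform: return a new edge list without (u, u) edges."""
    return [(u, v) for u, v in edges if u != v]


toy = ToyDataset([(0, 1), (1, 1), (2, 0)], transform=drop_self_loops)
print(toy[0])  # [(0, 1), (2, 0)]
```

Because the transform is applied at access time in this sketch, the stored data stays unchanged, and an expensive transform would be paid for on each lookup.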
Examples
>>> import torch as th
>>> from dgl.data import WN18Dataset
>>>
>>> dataset = WN18Dataset()
>>> g = dataset[0]  # the dataset holds a single graph
>>> e_type = g.edata['e_type']
>>>
>>> # get data split
>>> train_mask = g.edata['train_mask']
>>> val_mask = g.edata['val_mask']
>>>
>>> train_set = th.arange(g.num_edges())[train_mask]
>>> val_set = th.arange(g.num_edges())[val_mask]
>>>
>>> # build train_g
>>> train_edges = train_set
>>> train_g = g.edge_subgraph(train_edges, relabel_nodes=False)
>>> train_g.edata['e_type'] = e_type[train_edges]
>>>
>>> # build val_g from the training and validation edges
>>> val_edges = th.cat([train_edges, val_set])
>>> val_g = g.edge_subgraph(val_edges, relabel_nodes=False)
>>> val_g.edata['e_type'] = e_type[val_edges]
>>>
>>> # Train, Validation and Test
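The example above turns boolean edge masks into edge ID lists and builds the validation graph from the training edges plus the validation edges. The same index logic can be sketched without DGL or PyTorch. The masks below are made up, and `mask_to_ids` is a hypothetical helper, not part of any library.

```python
# Made-up masks over 6 edges; not real WN18 data.
train_mask = [True, True, False, True, False, False]
val_mask = [False, False, True, False, False, False]


def mask_to_ids(mask):
    """Return the positions where the mask is True."""
    return [i for i, flag in enumerate(mask) if flag]


train_edges = mask_to_ids(train_mask)
val_only = mask_to_ids(val_mask)

# The validation graph keeps the training edges and adds the validation edges.
val_edges = train_edges + val_only

print(train_edges)  # [0, 1, 3]
print(val_edges)    # [0, 1, 3, 2]
```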
- __getitem__(idx)
Gets the graph object.
- Parameters:
idx (int) – Item index. WN18Dataset has only one graph object.
- Returns:
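The graph contains:
edata['e_type']: edge relation type
edata['train_edge_mask']: positive training edge mask
edata['val_edge_mask']: positive validation edge mask
edata['test_edge_mask']: positive testing edge mask
edata['train_mask']: training edge set mask (includes reversed training edges)
edata['val_mask']: validation edge set mask (includes reversed validation edges)
edata['test_mask']: testing edge set mask (includes reversed testing edges)
ndata['ntype']: node type. All 0 in this dataset.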
- Return type: