PubmedGraphDataset
- class dgl.data.PubmedGraphDataset(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None, reorder=False)
Bases: CitationGraphDataset
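Pubmed citation network dataset.
Nodes represent scientific publications and edges represent citations. Each node has a predefined 500-dimensional feature vector. The dataset is designed for node classification, where the task is to predict the category of a subset of the publications.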
Statistics:
Nodes: 19717
Edges: 88651
Number of classes: 3
Label split:
Train: 60
Valid: 500
Test: 1000
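The label split above is exposed as boolean node masks, not as separate graphs. The sketch below shows how a 60/500/1000 split over 19717 nodes translates into three disjoint masks, and how most nodes end up in none of them. The index ranges and the `index_to_mask` helper are hypothetical placeholders, not Pubmed's actual split indices, and plain Python lists stand in for the tensors DGL stores.

```python
# Hypothetical sketch: turn index lists into boolean node masks.
NUM_NODES = 19717  # node count from the statistics above


def index_to_mask(indices, num_nodes):
    """Return a list of booleans that is True exactly at the given node indices."""
    mask = [False] * num_nodes
    for i in indices:
        mask[i] = True
    return mask


# Illustrative index ranges only; the real split comes from the dataset files.
train_idx = range(0, 60)
val_idx = range(60, 560)
test_idx = range(NUM_NODES - 1000, NUM_NODES)

train_mask = index_to_mask(train_idx, NUM_NODES)
val_mask = index_to_mask(val_idx, NUM_NODES)
test_mask = index_to_mask(test_idx, NUM_NODES)

print(sum(train_mask), sum(val_mask), sum(test_mask))  # 60 500 1000

# Nodes outside all three sets are still in the graph; they just have no split role.
unused = sum(not (a or b or c) for a, b, c in zip(train_mask, val_mask, test_mask))
print(unused)  # 18157
```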
- Parameters:
raw_dir (str) – Directory where the raw input data is downloaded to, or that already contains it. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: True.
reverse_edge (bool) – Whether to add reverse edges in graph. Default: True.
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
reorder (bool) – Whether to reorder the graph using reorder_graph(). Default: False.
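To make the `reverse_edge` and `transform` semantics concrete, here is a minimal pure-Python sketch. `ToyGraphDataset` is a hypothetical stand-in, not DGL's implementation, and a plain dict replaces the DGLGraph. It shows two things: reverse edges are added once, when the dataset is built, while the transform runs again on every indexing access.

```python
class ToyGraphDataset:
    """Hypothetical stand-in mimicking the reverse_edge and transform parameters."""

    def __init__(self, edges, reverse_edge=True, transform=None):
        if reverse_edge:
            # Add v->u for every u->v, keeping each directed edge only once.
            edges = sorted(set(edges) | {(v, u) for u, v in edges})
        self._edges = list(edges)
        self._transform = transform

    def __len__(self):
        # Like PubmedGraphDataset, this toy dataset holds a single graph.
        return 1

    def __getitem__(self, idx):
        if idx != 0:
            raise IndexError("this dataset has only one graph")
        graph = {"edges": list(self._edges)}
        # The transform is applied on each access, not once at construction.
        return graph if self._transform is None else self._transform(graph)


calls = []


def count_edges(graph):
    """Example transform: record that it ran and return the edge count."""
    calls.append(1)
    return len(graph["edges"])


ds = ToyGraphDataset([(0, 1), (1, 2)], transform=count_edges)
print(ds[0], ds[0], len(calls))  # 4 4 2
```

The second access re-runs `count_edges`, which is why `calls` has two entries. A transform with side effects or a high cost is paid on every `dataset[0]`.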
Notes
Node features are row-normalized.
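Row normalization rescales each node's feature vector independently. The sketch below assumes the common L1 (row-sum) variant used for these citation datasets. Each non-negative row is divided by its sum so that it sums to 1, and all-zero rows are left as zeros instead of producing a division by zero. The `row_normalize` helper is hypothetical and works on plain lists rather than on the tensor stored in `g.ndata['feat']`.

```python
def row_normalize(rows):
    """Divide each row by its sum (L1 row normalization); all-zero rows stay zero."""
    out = []
    for row in rows:
        total = sum(row)
        if total != 0:
            out.append([x / total for x in row])
        else:
            # Avoid dividing by zero for nodes with no non-zero features.
            out.append([0.0] * len(row))
    return out


feat = row_normalize([[1.0, 3.0], [0.0, 0.0], [2.0, 2.0]])
print(feat)  # [[0.25, 0.75], [0.0, 0.0], [0.5, 0.5]]
```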
Examples
>>> from dgl.data import PubmedGraphDataset
>>> dataset = PubmedGraphDataset()
>>> g = dataset[0]
>>> num_class = dataset.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
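Masks like `train_mask` and `test_mask` are normally used to restrict a metric or loss to one node set. The following is a minimal pure-Python sketch of masked accuracy, assuming per-node predicted classes are already available. The `masked_accuracy` helper is hypothetical, and the short lists stand in for the per-node tensors returned above.

```python
def masked_accuracy(pred, label, mask):
    """Fraction of masked nodes whose predicted class equals the ground-truth label."""
    hits = [p == y for p, y, m in zip(pred, label, mask) if m]
    if not hits:
        raise ValueError("mask selects no nodes")
    return sum(hits) / len(hits)


pred = [0, 1, 2, 1]
label = [0, 1, 1, 1]
mask = [True, True, True, False]  # the last node is excluded from the metric
print(round(masked_accuracy(pred, label, mask), 3))  # 0.667
```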
- __getitem__(idx)
Get the graph object.
- Parameters:
idx (int) – Item index. PubmedGraphDataset has only one graph object.
- Returns:
Graph structure, node features and labels.
ndata['train_mask']: mask for the training node set
ndata['val_mask']: mask for the validation node set
ndata['test_mask']: mask for the test node set
ndata['feat']: node features
ndata['label']: ground-truth labels
- Return type: