CiteseerGraphDataset
- class dgl.data.CiteseerGraphDataset(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None, reorder=False)
Bases: CitationGraphDataset
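Citeseer citation network dataset.
Nodes represent scientific publications and edges represent citations between them. Each node has a predefined 3703-dimensional feature vector. The dataset is designed for node classification: the task is to predict the category of a given publication.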
Statistics:
- Nodes: 3327
- Edges: 9228
- Number of classes: 6
- Label split:
  - Train: 120
  - Valid: 500
  - Test: 1000
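The split sizes above imply a very low label rate, which is what makes Citeseer a semi-supervised benchmark. The sketch below only does the arithmetic on the listed figures. It assumes the three splits are disjoint, and the names `split_fractions`, `NUM_NODES` and `SPLIT_SIZES` are hypothetical. The average of 120 / 6 = 20 training nodes per class matches the widely used Planetoid split of 20 labeled nodes per class.

```python
# Figures copied from the statistics listed above.
NUM_NODES = 3327
NUM_CLASSES = 6
SPLIT_SIZES = {"train": 120, "valid": 500, "test": 1000}


def split_fractions(num_nodes, split_sizes):
    """Return the fraction of all nodes that falls into each split."""
    return {name: size / num_nodes for name, size in split_sizes.items()}


fractions = split_fractions(NUM_NODES, SPLIT_SIZES)

# Assuming the splits are disjoint, count nodes inside and outside any split.
in_any_split = sum(SPLIT_SIZES.values())
outside_splits = NUM_NODES - in_any_split

# Average number of training nodes per class.
train_per_class = SPLIT_SIZES["train"] // NUM_CLASSES

print(f"train label rate: {fractions['train']:.3f}")
print(f"nodes in a split: {in_any_split}, outside all splits: {outside_splits}")
print(f"training nodes per class: {train_per_class}")
```

With these figures, only about 3.6% of the nodes are labeled for training, and 1707 nodes belong to no split at all.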
- Parameters:
raw_dir (str) – Directory to which the raw input data is downloaded, or which already contains it. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: True.
reverse_edge (bool) – Whether to add reverse edges to the graph. Default: True.
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
reorder (bool) – Whether to reorder the graph using reorder_graph(). Default: False.
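To make the `reverse_edge` and `transform` semantics concrete, here is a minimal pure-Python sketch. It is not DGL's implementation. The graph is modeled as a plain list of `(src, dst)` edges instead of a `DGLGraph`, and `add_reverse_edges` and `TinyGraphDataset` are hypothetical names. Reverse edges are added once, at construction time. The transform is applied every time the single graph is accessed, and the transformed result is never cached. `reorder` is not modeled.

```python
from typing import Callable, List, Optional, Tuple

Edge = Tuple[int, int]


def add_reverse_edges(edges: List[Edge]) -> List[Edge]:
    """Return the edges plus (v, u) for every (u, v), without duplicates."""
    seen = set()
    out = []
    for u, v in edges + [(v, u) for (u, v) in edges]:
        if (u, v) not in seen:
            seen.add((u, v))
            out.append((u, v))
    return out


class TinyGraphDataset:
    """Hypothetical single-graph dataset mirroring reverse_edge and transform."""

    def __init__(self, edges: List[Edge], reverse_edge: bool = True,
                 transform: Optional[Callable[[List[Edge]], List[Edge]]] = None):
        # Reverse edges are added once, when the dataset is built.
        self._edges = add_reverse_edges(edges) if reverse_edge else list(edges)
        self._transform = transform

    def __len__(self) -> int:
        return 1  # like Citeseer, the dataset holds exactly one graph

    def __getitem__(self, idx: int) -> List[Edge]:
        if idx != 0:
            raise IndexError("this dataset holds only one graph")
        graph = list(self._edges)  # hand out a copy; the stored graph stays intact
        if self._transform is not None:
            graph = self._transform(graph)  # runs on every access, not cached
        return graph


# Usage: count how often the transform runs.
calls = []


def drop_self_loops(edges: List[Edge]) -> List[Edge]:
    calls.append(1)
    return [(u, v) for (u, v) in edges if u != v]


ds = TinyGraphDataset([(0, 1), (1, 2), (2, 2)], reverse_edge=True,
                      transform=drop_self_loops)
g1 = ds[0]
g2 = ds[0]
print(g1)          # edges after reversal and self-loop removal
print(len(calls))  # the transform ran once per access
```

Running the transform on every access keeps the stored graph untouched. The trade-off is that an expensive transform is recomputed each time the graph is indexed.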
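Notes
Node features are row-normalized.
The Citeseer graph contains some isolated nodes. These isolated nodes are inserted at their correct positions as zero vectors.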
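Row normalization and the zero-vector rows interact: a row that sums to zero cannot be divided by its sum. The sketch below assumes the common Planetoid-style preprocessing, in which each feature row is divided by its sum and all-zero rows are left as zeros. The page does not state the exact formula, so treat this as an illustration. `row_normalize` is a hypothetical helper, and the tiny feature matrix is made up.

```python
from typing import List


def row_normalize(features: List[List[float]]) -> List[List[float]]:
    """Divide each row by its sum; rows that sum to zero stay all zeros."""
    normalized = []
    for row in features:
        total = sum(row)
        if total == 0:
            # e.g. an isolated node inserted as a zero vector: avoid division by 0
            normalized.append([0.0] * len(row))
        else:
            normalized.append([x / total for x in row])
    return normalized


feats = [
    [1.0, 0.0, 1.0, 0.0],  # bag-of-words style row with two active words
    [0.0, 0.0, 0.0, 0.0],  # zero vector, standing in for an isolated node
    [0.0, 3.0, 1.0, 0.0],
]
normed = row_normalize(feats)
print(normed)
print([sum(r) for r in normed])  # non-zero rows sum to 1, zero rows stay 0
```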
Examples
>>> from dgl.data import CiteseerGraphDataset
>>> dataset = CiteseerGraphDataset()
>>> g = dataset[0]
>>> num_class = dataset.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
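The split masks are typically used to restrict evaluation to one subset of nodes. The following sketch shows that logic with plain Python lists standing in for `g.ndata['label']`, a model's predictions, and a mask. `masked_accuracy` is a hypothetical helper, and the six-node data is made up. With PyTorch tensors, the equivalent is usually written with boolean indexing, for example `(pred[mask] == label[mask]).float().mean()`.

```python
from typing import Sequence


def masked_accuracy(pred: Sequence[int], label: Sequence[int],
                    mask: Sequence[bool]) -> float:
    """Accuracy computed only over the nodes whose mask entry is True."""
    picked = [(p, y) for p, y, m in zip(pred, label, mask) if m]
    if not picked:
        raise ValueError("mask selects no nodes")
    return sum(p == y for p, y in picked) / len(picked)


# Made-up stand-ins for labels, predictions and two split masks.
label = [0, 1, 2, 1, 0, 2]
pred = [0, 1, 1, 1, 2, 2]
train_mask = [True, True, False, False, False, False]
test_mask = [False, False, True, True, True, True]

train_acc = masked_accuracy(pred, label, train_mask)
test_acc = masked_accuracy(pred, label, test_mask)
print(f"train acc: {train_acc:.2f}, test acc: {test_acc:.2f}")
```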
- __getitem__(idx)
Get the graph object.
- Parameters:
idx (int) – Item index. CiteseerGraphDataset has only one graph object.
- Returns:
graph structure, node features and labels:
ndata['train_mask']: mask for the training node set
ndata['val_mask']: mask for the validation node set
ndata['test_mask']: mask for the test node set
ndata['feat']: node features
ndata['label']: ground-truth labels
- Return type: