WikiCS Dataset
- class dgl.data.WikiCSDataset(raw_dir=None, force_reload=False, verbose=False, transform=None)
Bases: DGLBuiltinDataset
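Wiki-CS is a Wikipedia-based node classification dataset introduced in Wiki-CS: A Wikipedia-Based Benchmark for Graph Neural Networks.
The dataset consists of nodes corresponding to computer science articles, with edges based on hyperlinks, and 10 classes representing different branches of the field.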
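WikiCS dataset statistics:
Nodes: 11,701
Edges: 431,726 (the original dataset has 216,123 edges, but DGL adds reverse edges and removes duplicate edges, so the count differs)
Number of classes: 10
Node feature size: 300
Number of distinct train, validation and stopping splits: 20
Number of test splits: 1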
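The edge count above comes from symmetrizing the edge list: every edge (u, v) gets a reverse edge (v, u), and pairs that end up appearing more than once are collapsed. The plain-Python sketch below illustrates that bookkeeping on a toy edge list; `symmetrize` is a hypothetical helper, not DGL's actual preprocessing code.

```python
from typing import Iterable, List, Tuple

Edge = Tuple[int, int]


def symmetrize(edges: Iterable[Edge]) -> List[Edge]:
    """Add the reverse of every directed edge and drop duplicate pairs.

    Hypothetical helper illustrating "add reverse edges, remove duplicates";
    it is not DGL's internal implementation.
    """
    seen = set()
    for u, v in edges:
        seen.add((u, v))
        seen.add((v, u))  # a self-loop (u, u) maps to itself and is kept once
    return sorted(seen)


# Toy edge list: (0, 1) and (1, 0) already point both ways; (2, 2) is a self-loop.
raw = [(0, 1), (1, 0), (1, 2), (2, 2)]
sym = symmetrize(raw)
print(len(raw), len(sym))  # 4 5
```

Edges that already exist in both directions, and self-loops, do not double under this operation, which is why a symmetrized count need not be exactly twice the original count.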
- Parameters:
raw_dir (str) – Directory to which the raw input data is downloaded, or in which it is already stored. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: False
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
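The `transform` parameter is a callable applied to the graph each time it is accessed, rather than once at load time. The sketch below models that behavior with a hypothetical `ToyDataset` class and a plain dict standing in for a DGLGraph; it illustrates the documented "transformed before every access" semantics and is not DGL's source code. With DGL itself, a built-in transform such as `dgl.AddSelfLoop()` can be passed as `transform`.

```python
from typing import Any, Callable, Dict, Optional

Graph = Dict[str, Any]  # stand-in for a DGLGraph in this sketch


class ToyDataset:
    """Hypothetical single-graph dataset mimicking the `transform` semantics."""

    def __init__(self, graph: Graph,
                 transform: Optional[Callable[[Graph], Graph]] = None) -> None:
        self._graph = graph
        self._transform = transform

    def __len__(self) -> int:
        return 1

    def __getitem__(self, idx: int) -> Graph:
        if idx != 0:
            raise IndexError("this dataset holds a single graph")
        # The transform is re-applied on every access; its output is not cached.
        if self._transform is None:
            return self._graph
        return self._transform(self._graph)


calls = []


def add_flag(g: Graph) -> Graph:
    calls.append(1)  # record each invocation
    return {**g, "transformed": True}  # return a new dict, leaving the input intact


ds = ToyDataset({"num_nodes": 3}, transform=add_flag)
g1 = ds[0]
g2 = ds[0]
print(len(calls), g1["transformed"], "transformed" in ds._graph)  # 2 True False
```

Because `add_flag` returns a copy, the stored graph stays unchanged; a transform that mutated its input in place would instead modify the stored object on each access.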
Examples
>>> from dgl.data import WikiCSDataset
>>> dataset = WikiCSDataset()
>>> dataset.num_classes
10
>>> g = dataset[0]
>>> # get node features
>>> feat = g.ndata['feat']
>>> # get node labels
>>> labels = g.ndata['label']
>>> # get data splits
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> stopping_mask = g.ndata['stopping_mask']
>>> test_mask = g.ndata['test_mask']
>>> # The train, val and stopping masks have shape (num_nodes, num_splits),
>>> # where num_splits is the number of distinct train/validation/stopping splits.
>>> # Since there is only one test split, the test mask has shape (num_nodes,).
>>> print(train_mask.shape, val_mask.shape, stopping_mask.shape)
(11701, 20) (11701, 20) (11701, 20)
>>> print(test_mask.shape)
(11701,)
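Because the train, validation and stopping masks carry one column per split, a single run typically uses one column at a time; with PyTorch tensors that column is `train_mask[:, i]`, while the one-dimensional test mask is shared by every split. The plain-Python sketch below shows the same column selection on a small nested-list mask; `split_node_ids` and the toy mask are hypothetical and only mirror the (num_nodes, num_splits) layout.

```python
from typing import List, Sequence


def split_node_ids(mask: Sequence[Sequence[bool]], split: int) -> List[int]:
    """Return IDs of nodes selected by column `split` of a (num_nodes, num_splits) mask.

    Hypothetical helper; with a PyTorch tensor the same column is mask[:, split].
    """
    return [node for node, row in enumerate(mask) if row[split]]


# Toy mask: 4 nodes x 2 splits (WikiCS itself has 11,701 nodes x 20 splits).
train_mask = [
    [True, False],
    [False, True],
    [True, True],
    [False, False],
]
num_splits = len(train_mask[0])
for i in range(num_splits):
    print(i, split_node_ids(train_mask, i))
# 0 [0, 2]
# 1 [1, 2]
```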