Reddit数据集
- class dgl.data.RedditDataset(self_loop=False, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]
Bases:
DGLBuiltinDataset
用于社区检测(节点分类)的Reddit数据集
这是一个来自2014年9月Reddit帖子的图数据集。 在这种情况下,节点标签是帖子所属的社区或“subreddit”。 作者抽取了50个大型社区,并构建了一个帖子到帖子的图,如果同一用户对两个帖子都进行了评论,则将这两个帖子连接起来。 该数据集总共包含232,965个帖子,平均度为492。我们使用前20天进行训练,剩余的天数用于测试(其中30%用于验证)。
Reference: http://snap.stanford.edu/graphsage/
统计
节点数:232,965
边数:114,615,892
节点特征大小:602
训练样本数量:153,431
验证样本数量:23,831
测试样本数量:55,703
- Parameters:
self_loop (bool) – 是否加载带有自环连接的数据集。默认值:False
raw_dir (str) – Raw file directory to download/contains the input data directory. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: True.
transform (callable, optional) – A transform that takes in a
DGLGraph
object and returns a transformed version. TheDGLGraph
object will be transformed before every access.
示例
>>> data = RedditDataset() >>> g = data[0] >>> num_classes = data.num_classes >>> >>> # get node feature >>> feat = g.ndata['feat'] >>> >>> # get data split >>> train_mask = g.ndata['train_mask'] >>> val_mask = g.ndata['val_mask'] >>> test_mask = g.ndata['test_mask'] >>> >>> # get labels >>> label = g.ndata['label'] >>> >>> # Train, Validation and Test