Squirrel数据集

class dgl.data.SquirrelDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)[source]

Bases: GeomGCNDataset

关于松鼠的维基百科页面网络来自多尺度属性节点嵌入,后来由Geom-GCN: 几何图卷积网络修改

节点代表来自英文维基百科的文章,边反映它们之间的相互链接。节点特征表示文章中特定名词的存在。根据它们的平均月流量,节点被分类为5个类别。

统计:

  • 节点数:5201

  • 边数: 217073

  • 班级数量:5

  • 10 个训练/验证/测试分割

    • 训练: 2496

    • 值: 1664

    • 测试: 1041

Parameters:
  • raw_dir (str, optional) – Raw file directory to store the processed data. Default: ~/.dgl/

  • force_reload (bool, optional) – Whether to re-download the data source. Default: False

  • verbose (bool, optional) – Whether to print progress information. Default: True

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access. Default: None

num_classes

节点类的数量

Type:

int

注释

该图不包含双向的边。

示例

>>> from dgl.data import SquirrelDataset
>>> dataset = SquirrelDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get data split
>>> train_mask = g.ndata["train_mask"]
>>> val_mask = g.ndata["val_mask"]
>>> test_mask = g.ndata["test_mask"]
>>> # get labels
>>> label = g.ndata['label']
__getitem__(idx)

获取索引处的数据对象。

__len__()

数据集中的示例数量。