树形网格数据集

class dgl.data.TreeGridDataset(tree_height=8, num_motifs=80, grid_size=3, perturb_ratio=0.1, seed=None, raw_dir=None, force_reload=False, verbose=True, transform=None)[source]

Bases: DGLBuiltinDataset

来自GNNExplainer: 生成图神经网络的解释的TREE-GRIDS数据集

这是一个用于节点分类的合成数据集。它是通过按顺序执行以下步骤生成的。

  • 构建一个平衡二叉树作为基础图。

  • 构建一组n×n网格图案。

  • 将图案附加到基础图的随机选择的节点上。

  • 通过添加随机边来扰动图。

  • 为所有节点生成常量特征,其值为1。

  • 树中的节点属于类别0,网格中的节点属于类别1。

Parameters:
  • tree_height (int, optional) – Height of the balanced binary tree. Default: 8

  • num_motifs (int, optional) – 使用的网格图案数量。默认值:80

  • grid_size (int, optional) – 网格图案中的节点数将为 grid_size ^ 2。默认值:3

  • perturb_ratio (float, optional) – 扰动中添加的随机边数除以图中原始边数的比例。默认值:0.1

  • seed (integer, random_state, or None, optional) – Indicator of random number generation state. Default: None

  • raw_dir (str, optional) – Raw file directory to store the processed data. Default: ~/.dgl/

  • force_reload (bool, optional) – Whether to always generate the data from scratch rather than load a cached version. Default: False

  • verbose (bool, optional) – Whether to print progress information. Default: True

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access. Default: None

num_classes

节点类的数量

Type:

int

示例

>>> from dgl.data import TreeGridDataset
>>> dataset = TreeGridDataset()
>>> dataset.num_classes
2
>>> g = dataset[0]
>>> label = g.ndata['label']
>>> feat = g.ndata['feat']
__getitem__(idx)[source]

获取索引处的数据对象。

__len__()[source]

数据集中的示例数量。