TreeCycleDataset
- class dgl.data.TreeCycleDataset(tree_height=8, num_motifs=60, cycle_size=6, perturb_ratio=0.01, seed=None, raw_dir=None, force_reload=False, verbose=True, transform=None)[source]
Bases:
DGLBuiltinDataset
来自GNNExplainer: 生成图神经网络的解释的TREE-CYCLES数据集
这是一个用于节点分类的合成数据集。它是通过按顺序执行以下步骤生成的。
构建一个平衡二叉树作为基础图。
构建一组循环主题。
将图案附加到基础图的随机选择的节点上。
通过添加随机边来扰动图。
为所有节点生成常量特征,其值为1。
树中的节点属于类别0,循环中的节点属于类别1。
- Parameters:
tree_height (int, optional) – 平衡二叉树的高度。默认值:8
num_motifs (int, optional) – 使用的循环主题数量。默认值:60
cycle_size (int, 可选) – 循环模式中的节点数。默认值:6
perturb_ratio (float, optional) – 在扰动中添加的随机边数除以图中原始边数的比例。默认值:0.01
seed (integer, random_state, or None, optional) – Indicator of random number generation state. Default: None
raw_dir (str, optional) – Raw file directory to store the processed data. Default: ~/.dgl/
force_reload (bool, optional) – Whether to always generate the data from scratch rather than load a cached version. Default: False
verbose (bool, optional) – Whether to print progress information. Default: True
transform (callable, optional) – A transform that takes in a
DGLGraph
object and returns a transformed version. TheDGLGraph
object will be transformed before every access. Default: None
示例
>>> from dgl.data import TreeCycleDataset >>> dataset = TreeCycleDataset() >>> dataset.num_classes 2 >>> g = dataset[0] >>> label = g.ndata['label'] >>> feat = g.ndata['feat']