PATTERNDataset
- class dgl.data.PATTERNDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[source]
Bases:
DGLBuiltinDataset
用于图模式识别任务的PATTERN数据集。
每个图G包含5个社区,社区大小在[5, 35]之间随机选择。 每个社区的SBM为p = 0.5,q = 0.35,G上的节点特征 使用大小为3的词汇表(即{0, 1, 2})的均匀随机分布生成。 然后随机生成100个由20个节点组成的模式P,内部概率\(p_P\) = 0.5 和外部概率\(q_P\) = 0.5(即P中50%的节点连接到G)。P的节点特征 也生成为值为{0, 1, 2}的随机信号。图的大小为 44-188个节点。如果节点属于P,则输出节点标签的值为1,如果节点在G中,则值为0。
参考 https://arxiv.org/pdf/2003.00982.pdf
统计:
训练示例:10,000
有效示例:2,000
测试示例:2,000
每个节点的类别数:2
- Parameters:
mode (str) – 必须是以下之一('train', 'valid', 'test')。 默认值:'train'
raw_dir (str) – Raw file directory to download/contains the input data directory. Default: ~/.dgl/
force_reload (bool) – 是否重新加载数据集。 默认值:False
verbose (bool) – 是否打印进度信息。 默认值:False
transform (callable, optional) – A transform that takes in a
DGLGraph
object and returns a transformed version. TheDGLGraph
object will be transformed before every access.
示例
>>> from dgl.data import PATTERNDataset >>> data = PATTERNDataset(mode='train') >>> data.num_classes 2 >>> len(trainset) 10000 >>> data[0] Graph(num_nodes=108, num_edges=4884, ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int16)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})