PATTERNDataset

class dgl.data.PATTERNDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

Bases: DGLBuiltinDataset

用于图模式识别任务的PATTERN数据集。

每个图G包含5个社区,社区大小在[5, 35]之间随机选择。 每个社区的SBM为p = 0.5,q = 0.35,G上的节点特征 使用大小为3的词汇表(即{0, 1, 2})的均匀随机分布生成。 然后随机生成100个由20个节点组成的模式P,内部概率\(p_P\) = 0.5 和外部概率\(q_P\) = 0.5(即P中50%的节点连接到G)。P的节点特征 也生成为值为{0, 1, 2}的随机信号。图的大小为 44-188个节点。如果节点属于P,则输出节点标签的值为1,如果节点在G中,则值为0。

参考 https://arxiv.org/pdf/2003.00982.pdf

统计:

  • 训练示例:10,000

  • 有效示例:2,000

  • 测试示例:2,000

  • 每个节点的类别数:2

Parameters:
  • mode (str) – 必须是以下之一('train', 'valid', 'test')。 默认值:'train'

  • raw_dir (str) – Raw file directory to download/contains the input data directory. Default: ~/.dgl/

  • force_reload (bool) – 是否重新加载数据集。 默认值:False

  • verbose (bool) – 是否打印进度信息。 默认值:False

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.

num_classes

每个节点的类别数量。

Type:

int

示例

>>> from dgl.data import PATTERNDataset
>>> data = PATTERNDataset(mode='train')
>>> data.num_classes
2
>>> len(trainset)
10000
>>> data[0]
Graph(num_nodes=108, num_edges=4884, ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int16)}
edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
__getitem__(idx)[source]

获取第 idx 个样本。

Parameters:

idx (int) – The sample index.

Returns:

图结构、节点特征、节点标签和边特征。

  • ndata['feat']: 节点特征

  • ndata['label']: 节点标签

  • edata['feat']: 边特征

Return type:

dgl.DGLGraph

__len__()[source]

数据集中的示例数量。