MiniGCDataset

class dgl.data.MiniGCDataset(num_graphs, min_num_v, max_num_v, seed=0, save_graph=True, force_reload=False, verbose=False, transform=None)[source]

Bases: DGLDataset

合成图分类数据集类。

数据集包含8种不同类型的图表。

  • 类别 0 : 循环图

  • 类别 1 : 星形图

  • 类别 2:轮图

  • 第三类:棒棒糖图

  • 第4类:超立方体图

  • 第5类:网格图

  • 第6课:团图

  • 第7类:环形梯图

Parameters:
  • num_graphs (int) – 此数据集中的图的数量。

  • min_num_v (int) – 图的最小节点数

  • max_num_v (int) – 图的最大节点数

  • seed (int, 默认值为 0) – 用于数据生成的随机种子

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.

num_graphs

图形数量

Type:

int

min_num_v

最小节点数

Type:

int

max_num_v

节点的最大数量

Type:

int

num_classes

类的数量

Type:

int

示例

>>> data = MiniGCDataset(100, 16, 32, seed=0)

数据集实例是可迭代的

>>> len(data)
100
>>> g, label = data[64]
>>> g
Graph(num_nodes=20, num_edges=82,
      ndata_schemes={}
      edata_schemes={})
>>> label
tensor(5)

将图和标签分批用于小批量训练

>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
>>> batched_labels = torch.tensor(labels)
>>> batched_graphs
Graph(num_nodes=356, num_edges=1060,
      ndata_schemes={}
      edata_schemes={})
__getitem__(idx)[source]

获取第 idx 个样本。

Parameters:

idx (int) – The sample index.

Returns:

图表及其标签。

Return type:

(dgl.Graph, Tensor)

__len__()[source]

返回数据集中图的数量。