SBMMixtureDataset

class dgl.data.SBMMixtureDataset(n_graphs, n_nodes, n_communities, k=2, avg_deg=3, pq='Appendix_C', rng=None)[source]

Bases: DGLDataset

对称随机块模型混合

参考:使用层次图神经网络进行监督社区检测的附录C

Parameters:
  • n_graphs (int) – 图的数量。

  • n_nodes (int) – 节点数量。

  • n_communities (int) – 社区数量。

  • k (int, optional) – 乘数。默认值:2

  • avg_deg (int, optional) – 平均度数。默认值:3

  • pq (list of pair of nonnegative float or str, optional) – 随机密度。此参数用于未来的扩展, 目前它始终使用默认值。 默认值:Appendix_C

  • rng (numpy.random.RandomState, 可选) – 随机数生成器。如果未提供,则使用 seed=None 的 numpy.random.RandomState(), 如果可用,则从 /dev/urandom(或 Windows 的等效文件)读取数据,否则从时钟中获取种子。 默认值:None

Raises:

如果pq不是列表字符串,则会引发RuntimeError。

示例

>>> data = SBMMixtureDataset(n_graphs=16, n_nodes=10000, n_communities=2)
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(data, batch_size=1, collate_fn=data.collate_fn)
>>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader:
...     # your code here
__getitem__(idx)[source]

通过索引获取一个示例

Parameters:

idx (int) – Item index

Returns:

  • graph (dgl.DGLGraph) – 原始图

  • line_graph (dgl.DGLGraph) – graph的线图

  • graph_degree (numpy.ndarray) – graph中每个节点的入度

  • line_graph_degree (numpy.ndarray) – line_graph中每个节点的入度

  • pm_pd (numpy.ndarray) – 边指示矩阵Pm和Pd

__len__()[source]

数据集中的图表数量。

collate_fn(x)[source]

collate 函数用于 dataloader

Parameters:

x (tuple) –

包含以下内容的一批数据:

  • graph: dgl.DGLGraph

    原始图

  • line_graph: dgl.DGLGraph

    graph的线图

  • graph_degree: numpy.ndarray

    graph中每个节点的入度

  • line_graph_degree: numpy.ndarray

    line_graph中每个节点的入度

  • pm_pd: numpy.ndarray

    边指示矩阵Pm和Pd

Returns:

  • g_batch (dgl.DGLGraph) – 批处理图

  • lg_batch (dgl.DGLGraph) – 批处理线图

  • degg_batch (numpy.ndarray) – g_batch 中每个节点的入度批处理

  • deglg_batch (numpy.ndarray) – lg_batch 中每个节点的入度批处理

  • pm_pd_batch (numpy.ndarray) – 边指示矩阵 Pm 和 Pd 的批处理