SBMMixtureDataset
- class dgl.data.SBMMixtureDataset(n_graphs, n_nodes, n_communities, k=2, avg_deg=3, pq='Appendix_C', rng=None)[source]
Bases:
DGLDataset
对称随机块模型混合
参考:使用层次图神经网络进行监督社区检测的附录C
- Parameters:
n_graphs (int) – 图的数量。
n_nodes (int) – 节点数量。
n_communities (int) – 社区数量。
k (int, optional) – 乘数。默认值:2
avg_deg (int, optional) – 平均度数。默认值:3
pq (list of pair of nonnegative float or str, optional) – 随机密度。此参数用于未来的扩展, 目前它始终使用默认值。 默认值:Appendix_C
rng (numpy.random.RandomState, 可选) – 随机数生成器。如果未提供,则使用 seed=None 的 numpy.random.RandomState(), 如果可用,则从 /dev/urandom(或 Windows 的等效文件)读取数据,否则从时钟中获取种子。 默认值:None
- Raises:
如果pq不是列表或字符串,则会引发RuntimeError。 –
示例
>>> data = SBMMixtureDataset(n_graphs=16, n_nodes=10000, n_communities=2) >>> from torch.utils.data import DataLoader >>> dataloader = DataLoader(data, batch_size=1, collate_fn=data.collate_fn) >>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader: ... # your code here
- __getitem__(idx)[source]
通过索引获取一个示例
- Parameters:
idx (int) – Item index
- Returns:
graph (
dgl.DGLGraph
) – 原始图line_graph (
dgl.DGLGraph
) – graph的线图graph_degree (numpy.ndarray) – graph中每个节点的入度
line_graph_degree (numpy.ndarray) – line_graph中每个节点的入度
pm_pd (numpy.ndarray) – 边指示矩阵Pm和Pd
- collate_fn(x)[source]
collate 函数用于 dataloader
- Parameters:
x (tuple) –
包含以下内容的一批数据:
- graph:
dgl.DGLGraph
原始图
- graph:
- line_graph:
dgl.DGLGraph
graph的线图
- line_graph:
- graph_degree: numpy.ndarray
graph中每个节点的入度
- line_graph_degree: numpy.ndarray
line_graph中每个节点的入度
- pm_pd: numpy.ndarray
边指示矩阵Pm和Pd
- Returns:
g_batch (
dgl.DGLGraph
) – 批处理图lg_batch (
dgl.DGLGraph
) – 批处理线图degg_batch (numpy.ndarray) – g_batch 中每个节点的入度批处理
deglg_batch (numpy.ndarray) – lg_batch 中每个节点的入度批处理
pm_pd_batch (numpy.ndarray) – 边指示矩阵 Pm 和 Pd 的批处理