ClusterGCNSampler

class dgl.dataloading.ClusterGCNSampler(g, k, cache_path='cluster_gcn.pkl', balance_ntypes=None, balance_edges=False, mode='k-way', prefetch_ndata=None, prefetch_edata=None, output_device=None)[source]

Bases: Sampler

Cluster-GCN: 一种用于训练深度和大规模图卷积网络的高效算法中提取的集群采样器

此采样器首先使用METIS分区对图进行分区,然后将每个分区的节点缓存到给定缓存目录中的文件中。

采样器然后根据提供的分区ID选择图分区,取这些分区中所有节点的并集,并在其sample方法中返回一个诱导子图。

Parameters:
  • g (DGLGraph) – 原始图。必须是同质的并且在CPU上。

  • k (int) – The number of partitions.

  • cache_path (str) – 用于存储分区结果的缓存目录路径。

  • balance_ntypes – 传递给 dgl.metis_partition_assignment()

  • balkance_edges – 传递给 dgl.metis_partition_assignment()

  • mode – 传递给 dgl.metis_partition_assignment()

  • prefetch_ndata (list[str], optional) –

    The node data to prefetch for the subgraph.

    See guide-minibatch-prefetching for a detailed explanation of prefetching.

  • prefetch_edata (list[str], optional) –

    The edge data to prefetch for the subgraph.

    See guide-minibatch-prefetching for a detailed explanation of prefetching.

  • output_device (device, optional) – 输出子图或MFGs的设备。默认与分区索引的小批量相同。

示例

节点分类

使用此采样器,数据加载器将接受分区ID列表作为迭代的索引。例如,以下代码首先使用METIS将图分成1000个分区,在每次迭代时,它获取由20个随机选择的分区覆盖的节点所诱导的子图。

>>> num_parts = 1000
>>> sampler = dgl.dataloading.ClusterGCNSampler(g, num_parts)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, torch.arange(num_parts), sampler,
...     batch_size=20, shuffle=True, drop_last=False, num_workers=4)
>>> for subg in dataloader:
...     train_on(subg)