SAINTSampler

class dgl.dataloading.SAINTSampler(mode, budget, cache=True, prefetch_ndata=None, prefetch_edata=None, output_device='cpu')[source]

基础类:Sampler

GraphSAINT: 基于图采样的归纳学习方法中随机采样节点/边/路径

对于每次调用,采样器会采样一个节点子集,然后返回一个节点诱导子图。 采样节点子集有三种选项:

  • 对于'node'采样器,采样一个节点的概率与其出度成正比。

  • 'edge' 采样器首先采样一个边子集,然后使用这些边的端点。

  • 'walk' 采样器使用随机游走访问的节点。它均匀地选择一些根节点,然后从每个根节点执行固定长度的随机游走。

Parameters:
  • mode (str) – 使用的采样器,可以是 'node', 'edge', 或 'walk'

  • 预算 (inttuple[int]) –

    采样器配置。

    • 对于 'node' 采样器,预算指定每个采样子图中的节点数量。

    • 对于 'edge' 采样器,预算指定采样的边数以诱导一个子图。

    • 对于 'walk' 采样器,预算是一个元组。budget[0] 指定生成随机游走的根节点数量。budget[1] 指定随机游走的长度。

  • cache (bool, optional) – 如果为False,则不会缓存用于采样的概率数组。如果您希望在不同的图中使用采样器,则需要将其设置为False。

  • prefetch_ndata (list[str], optional) –

    要为子图预取的节点数据。

    有关预取的详细解释,请参见 guide-minibatch-prefetching

  • prefetch_edata (list[str], optional) –

    要为子图预取的边数据。

    有关预取的详细解释,请参见 guide-minibatch-prefetching

  • output_device (device, optional) – 输出子图的设备。

示例

>>> import torch
>>> from dgl.dataloading import SAINTSampler, DataLoader
>>> num_iters = 1000
>>> sampler = SAINTSampler(mode='node', budget=6000)
>>> # Assume g.ndata['feat'] and g.ndata['label'] hold node features and labels
>>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4)
>>> for subg in dataloader:
...     train_on(subg)