SAINTSampler
- class dgl.dataloading.SAINTSampler(mode, budget, cache=True, prefetch_ndata=None, prefetch_edata=None, output_device='cpu')[source]
基础类:
Sampler
从GraphSAINT: 基于图采样的归纳学习方法中随机采样节点/边/路径
对于每次调用,采样器会采样一个节点子集,然后返回一个节点诱导子图。 采样节点子集有三种选项:
对于
'node'
采样器,采样一个节点的概率与其出度成正比。'edge'
采样器首先采样一个边子集,然后使用这些边的端点。'walk'
采样器使用随机游走访问的节点。它均匀地选择一些根节点,然后从每个根节点执行固定长度的随机游走。
- Parameters:
mode (str) – 使用的采样器,可以是
'node'
,'edge'
, 或'walk'
。采样器配置。
对于
'node'
采样器,预算指定每个采样子图中的节点数量。对于
'edge'
采样器,预算指定采样的边数以诱导一个子图。对于
'walk'
采样器,预算是一个元组。budget[0] 指定生成随机游走的根节点数量。budget[1] 指定随机游走的长度。
cache (bool, optional) – 如果为False,则不会缓存用于采样的概率数组。如果您希望在不同的图中使用采样器,则需要将其设置为False。
prefetch_ndata (list[str], optional) –
要为子图预取的节点数据。
有关预取的详细解释,请参见 guide-minibatch-prefetching。
prefetch_edata (list[str], optional) –
要为子图预取的边数据。
有关预取的详细解释,请参见 guide-minibatch-prefetching。
output_device (device, optional) – 输出子图的设备。
示例
>>> import torch >>> from dgl.dataloading import SAINTSampler, DataLoader >>> num_iters = 1000 >>> sampler = SAINTSampler(mode='node', budget=6000) >>> # Assume g.ndata['feat'] and g.ndata['label'] hold node features and labels >>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4) >>> for subg in dataloader: ... train_on(subg)