NeighborSampler

class dgl.dataloading.NeighborSampler(fanouts, edge_dir='in', prob=None, mask=None, replace=False, prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, output_device=None, fused=True)[source]

Bases: BlockSampler

通过多层GNN的邻居采样构建节点表示的计算依赖关系的采样器。

此采样器将使每个节点从每种边类型的固定数量的邻居中收集消息。邻居是均匀选择的。

Parameters:
  • fanouts (list[int] or list[dict[etype, int]]) –

    每个GNN层中每种边类型要采样的邻居列表,第i个元素是第i个GNN层的fanout。

    如果只提供一个整数,DGL会假设每种边类型都有相同的fanout。

    如果某一层的某一边类型提供了-1,则该边类型的所有入边都将被包含。

  • edge_dir (str, 默认 'in') – 可以是 'in' `` 其中邻居将根据传入的边进行采样,或者 ``'out' 否则,与 dgl.sampling.sample_neighbors() 相同。

  • prob (str, optional) –

    如果给定,每个邻居被采样的概率与g.edata中给定名称的边特征值成比例。该特征必须是每条边上的标量。

    此参数与mask互斥。如果您想同时指定掩码和概率,请考虑将概率与掩码相乘。

  • mask (str, optional) –

    如果给定,只有当g.edata中具有给定名称的边掩码为True时,才能选择邻居。每条边的数据必须是布尔类型。

    此参数与prob互斥。如果你想同时指定掩码和概率,考虑将概率与掩码相乘。

  • replace (bool, default False) – 是否进行有放回的抽样

  • prefetch_node_feats (list[str] or dict[ntype, list[str]], optional) – The source node data to prefetch for the first MFG, corresponding to the input node features necessary for the first GNN layer.

  • prefetch_labels (list[str] or dict[ntype, list[str]], optional) – The destination node data to prefetch for the last MFG, corresponding to the node labels of the minibatch.

  • prefetch_edge_feats (list[str] or dict[etype, list[str]], optional) – The edge data names to prefetch for all the MFGs, corresponding to the edge features necessary for all GNN layers.

  • output_device (device, optional) – The device of the output subgraphs or MFGs. Default is the same as the minibatch of seed nodes.

  • fused (bool, 默认值 True) – 如果为True且设备是CPU,则调用融合的样本邻居。此版本要求seed_nodes是唯一的

示例

节点分类

为了在一组节点train_nid上训练一个3层GNN进行节点分类,在一个同质图上,每个节点分别从5、10、15个邻居接收消息,分别对应第一层、第二层和第三层(假设后端是PyTorch):

>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(blocks)

如果在异质图上进行训练,并且您希望每种边类型有不同的邻居数量,则应提供一个字典列表。每个字典将指定每种边类型要选择的邻居数量。

>>> sampler = dgl.dataloading.NeighborSampler([
...     {('user', 'follows', 'user'): 5,
...      ('user', 'plays', 'game'): 4,
...      ('game', 'played-by', 'user'): 3}] * 3)

如果您希望进行非均匀邻居采样:

>>> g.edata['p'] = torch.rand(g.num_edges())   # any non-negative 1D vector works
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='p')

或者在边缘掩码上进行采样:

>>> g.edata['mask'] = torch.rand(g.num_edges()) < 0.2   # any 1D boolean mask works
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='mask')

边分类和链接预测

这个类也可以与as_edge_prediction_sampler()一起用于边缘分类和链接预测。

>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)

See the documentation as_edge_prediction_sampler() for more details.

注释

For the concept of MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.