dgl.dataloading.as_edge_prediction_sampler

dgl.dataloading.as_edge_prediction_sampler(sampler, exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None, prefetch_labels=None)[source]

从节点采样器创建一个边采样器。

对于每批边,采样器将提供的节点采样器应用于它们的源节点和目标节点以提取子图。如果提供了负采样器,它还会生成负边,并为它们的事件节点提取子图。

每次迭代,采样器将产生

  • 用于计算边上表示的输入节点张量,或节点类型名称和此类张量的字典。

  • 一个仅包含小批量中的边及其相关节点的子图。 请注意,该子图与原始图具有相同的元图结构。

  • 如果提供了负采样器,另一个包含“负边”的图将被生成,这些负边连接由给定负采样器产生的源节点和目标节点。

  • 由提供的节点采样器返回的子图或MFGs,从小批量中的边的入射节点(以及负边的入射节点,如果适用)生成。

Parameters:
  • sampler (Sampler) – 节点采样器对象。它还需要sample方法必须有一个可选的第三个参数exclude_eids,表示要从邻域中排除的边ID。该参数将是一个张量(对于同构图)或一个边类型和张量的字典(对于异构图)。

  • exclude (Union[str, callable], optional) –

    是否以及如何排除与小批量中采样边相关的依赖关系。可能的值为

    • None,表示不排除任何边。

    • self,表示排除当前小批量中的边。

    • reverse_id,表示不仅排除当前小批量中的边,还根据参数reverse_eids中的ID映射排除它们的反向边。

    • reverse_types,表示不仅排除当前小批量中的边,还根据参数reverse_etypes排除存储在另一种类型中的它们的反向边。

    • 用户定义的排除规则。它是一个可调用对象,以当前小批量中的边作为单一参数,并应返回要排除的边。

  • reverse_eids (Tensordict[etype, Tensor], optional) –

    一个反向边ID映射的张量。第i个元素表示第i条边的反向边的ID。

    如果图是异质的,这个参数需要一个边类型和反向边ID映射张量的字典。

  • reverse_etypes (dict[etype, etype], optional) – 从原始边类型到其反向边类型的映射。

  • negative_sampler (callable, optional) – 负采样器。

  • prefetch_labels (list[str] or dict[etype, list[str]], optional) –

    要为返回的正对图预取的边标签。

    有关预取的详细解释,请参见 guide-minibatch-prefetching

示例

以下示例展示了如何在同质无向图上为一组边 train_eid 训练一个3层GNN进行边分类。每个节点从所有邻居接收消息。

给定一个源节点ID数组 src 和另一个目标节点ID数组 dst,以下代码创建了一个双向图:

>>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))

图中边 \(i\) 的反向边是边 \(i + |E|\)。因此,我们可以通过以下方式创建一个反向边映射 reverse_eids

>>> E = len(src)
>>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])

通过将reverse_eids传递给边采样器,当前小批量中的边及其反向边将从提取的子图中排除,以避免信息泄露。

>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_id', reverse_eids=reverse_eids)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)

对于链接预测,可以提供一个负采样器来采样负边。 下面的代码使用DGL的Uniform 为每条边生成5个负样本:

>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     sampler, exclude='reverse_id', reverse_eids=reverse_eids,
...     negative_sampler=neg_sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, neg_pair_graph, blocks)

对于异构图,反向边可能属于不同的关系。例如,下图中的关系“用户点击项目”和“项目被用户点击”是相互反向的。

>>> g = dgl.heterograph({
...     ('user', 'click', 'item'): (user, item),
...     ('item', 'clicked-by', 'user'): (item, user)})

为了正确地从每个小批量中排除边,设置 exclude='reverse_types' 并将字典 {'click': 'clicked-by', 'clicked-by': 'click'} 传递给 reverse_etypes 参数。

>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'})
>>> dataloader = dgl.dataloading.DataLoader(
...     g, {'click': train_eid}, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)

对于链接预测,提供一个负采样器来生成负样本:

>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
...     negative_sampler=neg_sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, neg_pair_graph, blocks)