dgl.sampling.sample_neighbors

dgl.sampling.sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, copy_ndata=True, copy_edata=True, _dist_training=False, exclude_edges=None, output_device=None)[source]

对给定节点的邻近边进行采样并返回诱导子图。

对于每个节点,将随机选择一定数量的入站(或当edge_dir == 'out'时为出站)边。返回的图将包含原始图中的所有节点,但仅包含采样的边。

Node/edge features are not preserved. The original IDs of the sampled edges are stored as the dgl.EID feature in the returned graph.

此函数支持GPU采样。更多详情请参阅6.8 使用GPU进行邻域采样

Parameters:
  • g (DGLGraph) – 图。可以在CPU或GPU上。

  • nodes (tensordict) –

    从中采样邻居的节点ID。

    此参数可以接受单个ID张量或节点类型和ID张量的字典。 如果给出单个张量,则图必须只有一种类型的节点。

  • fanout (intdict[etype, int]) –

    每种边类型上每个节点要采样的边数。

    此参数可以接受一个整数或一个边类型和整数的字典。 如果给定一个整数,DGL将为每种边类型的每个节点采样此数量的边。

    如果为单个边类型给定-1,将选择该边类型且概率不为零的所有相邻边。

  • edge_dir (str, optional) –

    Determines whether to sample inbound or outbound edges.

    Can take either in for inbound edges or out for outbound edges.

  • prob (str, optional) –

    用作节点每个相邻边的(未归一化)概率的特征名称。该特征必须为每条边只有一个元素。

    特征必须是非负浮点数或布尔值。否则,结果将是未定义的。

  • exclude_edges (tensordict) –

    在为种子节点采样邻居时要排除的边ID。

    此参数可以接受单个ID张量或边类型和ID张量的字典。 如果给出单个张量,则图必须只有一种类型的节点。

  • replace (bool, optional) – 如果为True,则进行有放回的抽样。

  • copy_ndata (bool, optional) –

    如果为True,新图的节点特征将从原始图中复制。如果为False,新图将不会有任何节点特征。

    (默认值: True)

  • copy_edata (bool, optional) –

    如果为True,新图的边特征将从原始图中复制。如果为False,新图将不会有任何边特征。

    (默认值: True)

  • _dist_training (bool, optional) –

    内部参数。请勿使用。

    (默认值: False)

  • output_device (Framework-specific device context object, optional) – The output device. Default is the same as the input graph.

Returns:

一个仅包含采样邻边的采样子图。

Return type:

DGLGraph

注释

如果 copy_ndatacopy_edata 为 True,则相同的张量将用作原始图和新图的节点或边特征。因此,用户应避免对新图的节点特征进行原地操作,以避免特征损坏。

示例

假设你有以下图表

>>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))

以及权重

>>> g.edata['prob'] = torch.FloatTensor([0., 1., 0., 1., 0., 1.])

为节点0和节点1采样一个入边:

>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 1)
>>> sg.edges(order='eid')
(tensor([1, 0]), tensor([0, 1]))
>>> sg.edata[dgl.EID]
tensor([2, 0])

为节点0和节点1采样一个入边,概率在边特征 prob中:

>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 1, prob='prob')
>>> sg.edges(order='eid')
(tensor([2, 1]), tensor([0, 1]))

fanout大于实际邻居数量且不进行替换时, DGL将选择所有邻居:

>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3)
>>> sg.edges(order='eid')
(tensor([1, 2, 0, 1]), tensor([0, 0, 1, 1]))

在种子节点的采样过程中排除某些EID:

>>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))
>>> g_edges = g.all_edges(form='all')``
(tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5]))
>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3, exclude_edges=[0, 1, 2])
>>> sg.all_edges(form='all')
(tensor([2, 1]), tensor([0, 1]), tensor([0, 1]))
>>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3])
tensor([False, False, False])
>>> g = dgl.heterograph({
...   ('drug', 'interacts', 'drug'): ([0, 0, 1, 1, 3, 2], [1, 2, 0, 1, 2, 0]),
...   ('drug', 'interacts', 'gene'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]),
...   ('drug', 'treats', 'disease'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0])})
>>> g_edges = g.all_edges(form='all', etype=('drug', 'interacts', 'drug'))
(tensor([0, 0, 1, 1, 3, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5]))
>>> excluded_edges  = {('drug', 'interacts', 'drug'): g_edges[2][:3]}
>>> sg = dgl.sampling.sample_neighbors(g, {'drug':[0, 1]}, 3, exclude_edges=excluded_edges)
>>> sg.all_edges(form='all', etype=('drug', 'interacts', 'drug'))
(tensor([2, 1]), tensor([0, 1]), tensor([0, 1]))
>>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3],etype=('drug', 'interacts', 'drug'))
tensor([False, False, False])