dgl.node_subgraph

dgl.node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True, output_device=None)[source]

返回在给定节点上诱导的子图。

节点诱导子图是指边的两个端点都在指定节点集中的图。除了提取子图外,DGL还会将提取的节点和边的特征复制到结果图中。这种复制是惰性的,只有在需要时才会进行数据移动。

如果图是异质的,DGL会为每个关系提取一个子图并将它们组合成结果图。因此,结果图具有与输入图相同的关系集。

Parameters:
  • graph (DGLGraph) – The graph to extract subgraphs from.

  • nodes (nodesdict[str, nodes]) –

    用于形成子图的节点,不能有任何重复值。否则结果将是未定义的。允许的节点格式有:

    • Int Tensor:每个元素是一个节点ID。张量必须与图的设备类型和ID数据类型相同。

    • iterable[int]:每个元素是一个节点ID。

    • Bool Tensor:每个\(i^{th}\)元素是一个布尔标志,表示节点\(i\)是否在子图中。

    如果图是同质的,可以直接传递上述格式。否则,参数必须是一个字典,键为节点类型,值为上述格式的节点ID。

  • relabel_nodes (bool, optional) – 如果为True,提取的子图将只包含指定节点集中的节点,并且会按顺序重新标记节点。

  • store_ids (bool, optional) – 如果为True,它将在生成的图的edata中以dgl.EID的名称存储提取的边的原始ID;如果relabel_nodesTrue,它还会在生成的图的ndata中以dgl.NID的名称存储指定节点的原始ID。

  • output_device (Framework-specific device context object, optional) – The output device. Default is the same as the input graph.

Returns:

G – The subgraph.

Return type:

DGLGraph

注释

This function discards the batch information. Please use dgl.DGLGraph.set_batch_num_nodes() and dgl.DGLGraph.set_batch_num_edges() on the transformed graph to maintain the information.

示例

以下示例使用PyTorch后端。

>>> import dgl
>>> import torch

从同质图中提取一个子图。

>>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))  # 5-node cycle
>>> sg = dgl.node_subgraph(g, [0, 1, 4])
>>> sg
Graph(num_nodes=3, num_edges=2,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([0, 2]), tensor([1, 0]))
>>> sg.ndata[dgl.NID]  # original node IDs
tensor([0, 1, 4])
>>> sg.edata[dgl.EID]  # original edge IDs
tensor([0, 4])

使用布尔掩码指定节点。

>>> nodes = torch.tensor([True, True, False, False, True])  # choose nodes [0, 1, 4]
>>> dgl.node_subgraph(g, nodes)
Graph(num_nodes=3, num_edges=2,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})

生成的子图也会从父图中复制特征。

>>> g.ndata['x'] = torch.arange(10).view(5, 2)
>>> sg = dgl.node_subgraph(g, [0, 1, 4])
>>> sg
Graph(num_nodes=3, num_edges=2,
      ndata_schemes={'x': Scheme(shape=(2,), dtype=torch.int64),
                     '_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.ndata['x']
tensor([[0, 1],
        [2, 3],
        [8, 9]])

从异质图中提取一个子图。

>>> g = dgl.heterograph({
>>>     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
>>>     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])
>>> })
>>> sub_g = dgl.node_subgraph(g, {'user': [1, 2]})
>>> sub_g
Graph(num_nodes={'game': 0, 'user': 2},
      num_edges={('user', 'follows', 'user'): 2, ('user', 'plays', 'game'): 0},
      metagraph=[('user', 'user', 'follows'), ('user', 'game', 'plays')])

另请参阅

edge_subgraph