dgl.node_type_subgraph

dgl.node_type_subgraph(graph, ntypes, output_device=None)[source]

返回由给定节点类型诱导的子图。

节点类型诱导的子图包含图中给定节点类型子集的所有节点,以及端点都在此子集中的任何边。除了提取子图外,DGL还将提取的节点和边的特征复制到结果图中。这种复制是惰性的,只有在需要时才会进行数据移动。

Parameters:
  • graph (DGLGraph) – The graph to extract subgraphs from.

  • ntypes (list[str]) – 子图中节点的类型名称。

  • output_device (Framework-specific device context object, optional) – The output device. Default is the same as the input graph.

Returns:

G – The subgraph.

Return type:

DGLGraph

注释

This function discards the batch information. Please use dgl.DGLGraph.set_batch_num_nodes() and dgl.DGLGraph.set_batch_num_edges() on the transformed graph to maintain the information.

示例

以下示例使用PyTorch后端。

>>> import dgl
>>> import torch

实例化一个异构图。

>>> g = dgl.heterograph({
>>>     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
>>>     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])
>>> })
>>> # Set node features
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])

获取子图。

>>> sub_g = g.node_type_subgraph(['user'])
>>> print(sub_g)
Graph(num_nodes=3, num_edges=3,
      ndata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)}
      edata_schemes={})

获取提取的节点特征。

>>> sub_g.nodes['user'].data['h']
tensor([[0.],
        [1.],
        [2.]])

另请参阅

edge_type_subgraph