dgl.in_subgraph
- dgl.in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None)
Return the subgraph induced on the inbound edges of all the edge types of the given nodes.
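An in subgraph is equivalent to creating a new graph from the incoming edges of the given nodes. In addition to extracting the subgraph, DGL also copies the features of the extracted nodes and edges to the resulting graph. The copy is lazy and incurs data movement only when needed.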
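To make the "new graph from the incoming edges" description concrete, here is a minimal pure-Python sketch that models the extraction on plain edge lists. The helpers `in_subgraph_edges` and `gather_edge_features` are hypothetical and are not part of DGL. The sketch assumes that edge IDs are list positions and that kept edges come out in increasing edge-ID order, which matches the 5-node cycle example in this page but is an assumption in general. Unlike DGL, it copies features eagerly.

```python
# Hypothetical pure-Python model of in-subgraph extraction on an edge list.
# NOT DGL's implementation: it only mirrors the documented semantics, and the
# increasing edge-ID ordering of the output is an assumption.

def in_subgraph_edges(src, dst, nodes):
    """Return (new_src, new_dst, eids) for every edge whose destination is in `nodes`.

    Edge IDs are the positions of the edges in the original `src`/`dst` lists.
    """
    targets = set(nodes)
    eids = [eid for eid, v in enumerate(dst) if v in targets]
    return [src[e] for e in eids], [dst[e] for e in eids], eids


def gather_edge_features(feat, eids):
    """Copy the feature rows of the kept edges (eagerly, unlike DGL's lazy copy)."""
    return [feat[e] for e in eids]


# The 5-node cycle: 0->1, 1->2, 2->3, 3->4, 4->0.
src, dst = [0, 1, 2, 3, 4], [1, 2, 3, 4, 0]
w = [[2 * i, 2 * i + 1] for i in range(5)]  # same values as torch.arange(10).view(5, 2)

new_src, new_dst, eids = in_subgraph_edges(src, dst, [2, 0])
print(new_src, new_dst, eids)         # [1, 4] [2, 0] [1, 4]
print(gather_edge_features(w, eids))  # [[2, 3], [8, 9]]
```

The kept edge IDs act like `sg.edata[dgl.EID]`, and indexing the original feature rows with them reproduces the `sg.edata['w']` values.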
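If the graph is heterogeneous, DGL extracts a subgraph per relation and composes them as the resulting graph. Thus, the resulting graph has the same set of relations as the input graph.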
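The per-relation behavior can be sketched the same way. In this hypothetical model (`hetero_in_subgraph_edges` is illustrative, not a DGL function), each canonical edge type keeps the edges whose destination is in the node set given for its destination type. A relation whose destination type is not listed keeps zero edges, so every input relation still appears in the result.

```python
# Hypothetical pure-Python model of the per-relation behavior on a heterogeneous
# graph. Relations are keyed by canonical edge type (src_type, etype, dst_type).
# This is an illustration of the documented semantics, not DGL's implementation.

def hetero_in_subgraph_edges(relations, nodes):
    """Extract in-edges per relation; every input relation appears in the result.

    relations: {(src_type, etype, dst_type): (src_ids, dst_ids)}
    nodes:     {node_type: node_ids}; node types not listed select no nodes.
    """
    result = {}
    for (stype, etype, dtype), (src, dst) in relations.items():
        targets = set(nodes.get(dtype, []))
        eids = [e for e, v in enumerate(dst) if v in targets]
        result[(stype, etype, dtype)] = ([src[e] for e in eids], [dst[e] for e in eids])
    return result


relations = {
    ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
    ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2]),
}
sub = hetero_in_subgraph_edges(relations, {'user': [2], 'game': [2]})
for rel, (s, d) in sub.items():
    print(rel, len(s))
# ('user', 'plays', 'game') 1
# ('user', 'follows', 'user') 2
```

The edge counts agree with the heterogeneous example in this page: one `plays` edge into game 2 and two `follows` edges into user 2.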
- Parameters:
graph (DGLGraph) – The input graph.
nodes (nodes or dict[str, nodes]) –
The nodes to form the subgraph, which must not contain duplicate values; otherwise the result is undefined. The allowed node formats are:
Int Tensor: Each element is a node ID. The tensor must have the same device type and ID data type as the graph’s.
iterable[int]: Each element is a node ID.
If the graph is homogeneous, one can directly pass the above formats. Otherwise, the argument must be a dictionary with keys being node types and values being the node IDs in the above formats.
relabel_nodes (bool, optional) – If True, it will remove the isolated nodes and relabel the remaining nodes in the extracted subgraph.
store_ids (bool, optional) – If True, it will store the raw IDs of the extracted edges in the edata of the resulting graph under name dgl.EID; if relabel_nodes is True, it will also store the raw IDs of the extracted nodes in the ndata of the resulting graph under name dgl.NID.
output_device (Framework-specific device context object, optional) – The output device. Default is the same as the input graph.
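Because duplicate node IDs make the result undefined, it can be convenient to deduplicate them before the call. The hypothetical helper `dedup_node_ids` below (not part of DGL) handles both accepted forms, a flat iterable for a homogeneous graph or a per-node-type dict for a heterogeneous one, and keeps first-occurrence order. It returns plain lists, which fall under the iterable[int] format.

```python
# Hypothetical helper: drop duplicate node IDs before calling dgl.in_subgraph,
# whose result is undefined when `nodes` contains duplicates. Not a DGL API.

def dedup_node_ids(nodes):
    """Deduplicate node IDs, keeping first-occurrence order.

    Accepts either an iterable of ints (homogeneous graph) or a
    {node_type: iterable of ints} dict (heterogeneous graph).
    """
    if isinstance(nodes, dict):
        return {ntype: dedup_node_ids(ids) for ntype, ids in nodes.items()}
    # dict.fromkeys preserves insertion order, so the first occurrence wins.
    return list(dict.fromkeys(int(n) for n in nodes))


print(dedup_node_ids([2, 0, 2, 2]))                         # [2, 0]
print(dedup_node_ids({'user': [2, 2], 'game': [1, 2, 1]}))  # {'user': [2], 'game': [1, 2]}
```

For an Int Tensor input, `torch.unique(nodes)` is an alternative that keeps the tensor's device and dtype, although it also sorts the IDs by default.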
- Returns:
The subgraph.
- Return type:
DGLGraph
Notes
This function discards the batch information. Please use dgl.DGLGraph.set_batch_num_nodes() and dgl.DGLGraph.set_batch_num_edges() on the transformed graph to maintain the information.
Examples
The following example uses the PyTorch backend.
>>> import dgl
>>> import torch
Extract a subgraph from a homogeneous graph.
>>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))  # 5-node cycle
>>> g.edata['w'] = torch.arange(10).view(5, 2)
>>> sg = dgl.in_subgraph(g, [2, 0])
>>> sg
Graph(num_nodes=5, num_edges=2,
      ndata_schemes={}
      edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([1, 4]), tensor([2, 0]))
>>> sg.edata[dgl.EID]  # original edge IDs
tensor([1, 4])
>>> sg.edata['w']  # also extract the features
tensor([[2, 3],
        [8, 9]])
Extract a subgraph with node labeling.
>>> sg = dgl.in_subgraph(g, [2, 0], relabel_nodes=True)
>>> sg
Graph(num_nodes=4, num_edges=2,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([1, 3]), tensor([2, 0]))
>>> sg.edata[dgl.EID]  # original edge IDs
tensor([1, 4])
>>> sg.ndata[dgl.NID]  # original node IDs
tensor([0, 1, 2, 4])
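The relabeling in the example above can be reproduced with a small pure-Python model. `relabel` is a hypothetical helper, not a DGL function. It assumes that the retained nodes are numbered in increasing order of their original IDs, which matches the output shown (`tensor([0, 1, 2, 4])`) but is not a documented guarantee. The returned `nid` list plays the role of `sg.ndata[dgl.NID]`.

```python
# Hypothetical pure-Python model of relabel_nodes=True applied to extracted edges.
# Assumption: retained nodes are ordered by original ID; this matches the example
# output but is not documented as a DGL guarantee.

def relabel(src, dst):
    """Drop isolated nodes and compact the remaining IDs to 0..k-1.

    Returns (new_src, new_dst, nid), where nid[i] is the original ID of new node i.
    """
    nid = sorted(set(src) | set(dst))  # only nodes touched by an edge survive
    old_to_new = {old: new for new, old in enumerate(nid)}
    return [old_to_new[u] for u in src], [old_to_new[v] for v in dst], nid


# In-edges of nodes [2, 0] in the 5-node cycle: 1->2 and 4->0.
new_src, new_dst, nid = relabel([1, 4], [2, 0])
print(new_src, new_dst, nid)  # [1, 3] [2, 0] [0, 1, 2, 4]
```

Node 3 has no extracted edge, so it is dropped and original node 4 becomes new node 3, which is why the relabeled edges read `(tensor([1, 3]), tensor([2, 0]))`.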
Extract a subgraph from a heterogeneous graph.
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
...     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])})
>>> sub_g = g.in_subgraph({'user': [2], 'game': [2]})
>>> sub_g
Graph(num_nodes={'game': 3, 'user': 3},
      num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
      metagraph=[('user', 'game', 'plays'), ('user', 'user', 'follows')])
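As the Notes section states, the batch information is discarded. When the input is a batched graph and `relabel_nodes=False`, the node counts are unchanged, so only the per-graph edge counts need recomputing. The hypothetical helper `subgraph_batch_num_edges` below derives them from the original edge IDs. It assumes that each component graph owns a contiguous block of edge IDs (as produced by `dgl.batch`) and that `store_ids=True` was used so the IDs are available in `sg.edata[dgl.EID]`.

```python
# Hypothetical helper for restoring batch information after dgl.in_subgraph.
# Assumptions: the input was built with dgl.batch (graph i owns a contiguous
# block of edge IDs), relabel_nodes=False, and store_ids=True.
from bisect import bisect_right
from itertools import accumulate


def subgraph_batch_num_edges(orig_batch_num_edges, kept_eids):
    """Count how many kept edges fall in each component graph's edge-ID block."""
    bounds = list(accumulate(orig_batch_num_edges))  # exclusive upper bound per graph
    counts = [0] * len(bounds)
    for eid in kept_eids:
        # The first bound strictly greater than eid identifies the owning graph.
        counts[bisect_right(bounds, eid)] += 1
    return counts


# Two graphs batched together with 3 and 2 edges (IDs 0-2 and 3-4).
print(subgraph_batch_num_edges([3, 2], [1, 3, 4]))  # [1, 2]
```

With a batched graph `bg` and `sg = dgl.in_subgraph(bg, nodes)`, the inputs could come from `bg.batch_num_edges().tolist()` and `sg.edata[dgl.EID].tolist()`; the counts would then be passed to `sg.set_batch_num_edges()` as a tensor on the graph's device, together with `sg.set_batch_num_nodes(bg.batch_num_nodes())`.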
See also