dgl.to_simple
- dgl.to_simple(g, return_counts='count', writeback_mapping=False, copy_ndata=True, copy_edata=False, aggregator='arbitrary')[source]
将图转换为没有平行边的简单图并返回。
对于具有多种边类型的异构图,DGL将具有相同边类型和端点的边视为平行边并将其移除。 可选地,可以通过指定
return_counts
参数来获取平行边的数量。要获取输入图中边ID到结果图中边ID的映射,请将writeback_mapping
设置为true。- Parameters:
g (DGLGraph) – 输入图。必须在CPU上。
return_counts (str, optional) –
如果给定,原始图中每条边的计数将作为边特征存储在名称
return_counts
下。同名的旧特征将被替换。(默认值: “count”)
writeback_mapping (bool, optional) –
如果为True,则为每种边类型返回一个额外的写回映射。写回映射是一个记录从输入图中的边ID到结果图中的边ID的映射的张量。如果图是异质的,DGL将返回一个边类型和此类张量的字典。
如果为False,则只返回简单的图。
(默认值: False)
copy_ndata (bool, optional) –
如果为True,简单图的节点特征将从原始图中复制。
如果为False,简单图将不会有任何节点特征。
(默认值: True)
copy_edata (bool, optional) –
如果为True,简单图的边特征将从原始图中复制。如果两个节点(u, v)之间存在重复边,则边的特征是重复边特征的聚合。
如果为False,简单图将没有任何边特征。
(默认值:False)
aggregator (str, optional) –
指示如何合并重复边的边特征。 如果为
arbitrary
,则选择其中一个重复边的特征。 如果为sum
,则计算重复边特征的总和。 如果为mean
,则计算重复边特征的平均值。(默认值:
arbitrary
)
- Returns:
DGLGraph – 图。
tensor or dict of tensor – 写回映射。仅当
writeback_mapping
为 True 时。
注释
If
copy_ndata
is True, the resulting graph will share the node feature tensors with the input graph. Hence, users should try to avoid in-place operations which will be visible to both graphs.This function discards the batch information. Please use
dgl.DGLGraph.set_batch_num_nodes()
anddgl.DGLGraph.set_batch_num_edges()
on the transformed graph to maintain the information.示例
同构图
创建一个图表来演示 to_simple API。 在原始图表中,1 和 2 之间有多条边。
>>> import dgl >>> import torch as th >>> g = dgl.graph((th.tensor([0, 1, 2, 1]), th.tensor([1, 2, 0, 2]))) >>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]]) >>> g.edata['h'] = th.tensor([[3.], [4.], [5.], [6.]])
将图转换为简单图。返回的计数存储在边特征‘cnt’中,写回映射以张量形式返回。
>>> sg, wm = dgl.to_simple(g, return_counts='cnt', writeback_mapping=True) >>> sg.ndata['h'] tensor([[0.], [1.], [2.]]) >>> u, v, eid = sg.edges(form='all') >>> u tensor([0, 1, 2]) >>> v tensor([1, 2, 0]) >>> eid tensor([0, 1, 2]) >>> sg.edata['cnt'] tensor([1, 2, 1]) >>> wm tensor([0, 1, 2, 1]) >>> 'h' in g.edata False
异构图
>>> g = dgl.heterograph({ ... ('user', 'wins', 'user'): (th.tensor([0, 2, 0, 2, 2]), th.tensor([1, 1, 2, 1, 0])), ... ('user', 'plays', 'game'): (th.tensor([1, 2, 1]), th.tensor([2, 1, 1])) ... }) >>> g.nodes['game'].data['hv'] = th.ones(3, 1) >>> g.edges['plays'].data['he'] = th.zeros(3, 1)
返回的计数存储在每种边类型的默认边特征‘count’中。
>>> sg, wm = dgl.to_simple(g, copy_ndata=False, writeback_mapping=True) >>> sg Graph(num_nodes={'game': 3, 'user': 3}, num_edges={('user', 'wins', 'user'): 4, ('game', 'plays', 'user'): 3}, metagraph=[('user', 'user'), ('game', 'user')]) >>> sg.edges(etype='wins') (tensor([0, 2, 0, 2]), tensor([1, 1, 2, 0])) >>> wm[('user', 'wins', 'user')] tensor([0, 1, 2, 1, 3]) >>> sg.edges(etype='plays') (tensor([2, 1, 1]), tensor([1, 2, 1])) >>> wm[('user', 'plays', 'game')] tensor([0, 1, 2]) >>> 'hv' in sg.nodes['game'].data False >>> 'he' in sg.edges['plays'].data False >>> sg.edata['count'] {('user', 'wins', 'user'): tensor([1, 2, 1, 1]) ('user', 'plays', 'game'): tensor([1, 1, 1])}