dgl.to_simple

dgl.to_simple(g, return_counts='count', writeback_mapping=False, copy_ndata=True, copy_edata=False, aggregator='arbitrary')[source]

将图转换为没有平行边的简单图并返回。

对于具有多种边类型的异构图,DGL将具有相同边类型和端点的边视为平行边并将其移除。 可选地,可以通过指定return_counts参数来获取平行边的数量。要获取输入图中边ID到结果图中边ID的映射,请将writeback_mapping设置为true。

Parameters:
  • g (DGLGraph) – 输入图。必须在CPU上。

  • return_counts (str, optional) –

    如果给定,原始图中每条边的计数将作为边特征存储在名称return_counts下。同名的旧特征将被替换。

    (默认值: “count”)

  • writeback_mapping (bool, optional) –

    如果为True,则为每种边类型返回一个额外的写回映射。写回映射是一个记录从输入图中的边ID到结果图中的边ID的映射的张量。如果图是异质的,DGL将返回一个边类型和此类张量的字典。

    如果为False,则只返回简单的图。

    (默认值: False)

  • copy_ndata (bool, optional) –

    如果为True,简单图的节点特征将从原始图中复制。

    如果为False,简单图将不会有任何节点特征。

    (默认值: True)

  • copy_edata (bool, optional) –

    如果为True,简单图的边特征将从原始图中复制。如果两个节点(u, v)之间存在重复边,则边的特征是重复边特征的聚合。

    如果为False,简单图将没有任何边特征。

    (默认值:False)

  • aggregator (str, optional) –

    指示如何合并重复边的边特征。 如果为arbitrary,则选择其中一个重复边的特征。 如果为sum,则计算重复边特征的总和。 如果为mean,则计算重复边特征的平均值。

    (默认值: arbitrary)

Returns:

  • DGLGraph – 图。

  • tensor or dict of tensor – 写回映射。仅当 writeback_mapping 为 True 时。

注释

If copy_ndata is True, the resulting graph will share the node feature tensors with the input graph. Hence, users should try to avoid in-place operations which will be visible to both graphs.

This function discards the batch information. Please use dgl.DGLGraph.set_batch_num_nodes() and dgl.DGLGraph.set_batch_num_edges() on the transformed graph to maintain the information.

示例

同构图

创建一个图表来演示 to_simple API。 在原始图表中,1 和 2 之间有多条边。

>>> import dgl
>>> import torch as th
>>> g = dgl.graph((th.tensor([0, 1, 2, 1]), th.tensor([1, 2, 0, 2])))
>>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]])
>>> g.edata['h'] = th.tensor([[3.], [4.], [5.], [6.]])

将图转换为简单图。返回的计数存储在边特征‘cnt’中,写回映射以张量形式返回。

>>> sg, wm = dgl.to_simple(g, return_counts='cnt', writeback_mapping=True)
>>> sg.ndata['h']
tensor([[0.],
        [1.],
        [2.]])
>>> u, v, eid = sg.edges(form='all')
>>> u
tensor([0, 1, 2])
>>> v
tensor([1, 2, 0])
>>> eid
tensor([0, 1, 2])
>>> sg.edata['cnt']
tensor([1, 2, 1])
>>> wm
tensor([0, 1, 2, 1])
>>> 'h' in g.edata
False

异构图

>>> g = dgl.heterograph({
...     ('user', 'wins', 'user'): (th.tensor([0, 2, 0, 2, 2]), th.tensor([1, 1, 2, 1, 0])),
...     ('user', 'plays', 'game'): (th.tensor([1, 2, 1]), th.tensor([2, 1, 1]))
... })
>>> g.nodes['game'].data['hv'] = th.ones(3, 1)
>>> g.edges['plays'].data['he'] = th.zeros(3, 1)

返回的计数存储在每种边类型的默认边特征‘count’中。

>>> sg, wm = dgl.to_simple(g, copy_ndata=False, writeback_mapping=True)
>>> sg
Graph(num_nodes={'game': 3, 'user': 3},
      num_edges={('user', 'wins', 'user'): 4, ('game', 'plays', 'user'): 3},
      metagraph=[('user', 'user'), ('game', 'user')])
>>> sg.edges(etype='wins')
(tensor([0, 2, 0, 2]), tensor([1, 1, 2, 0]))
>>> wm[('user', 'wins', 'user')]
tensor([0, 1, 2, 1, 3])
>>> sg.edges(etype='plays')
(tensor([2, 1, 1]), tensor([1, 2, 1]))
>>> wm[('user', 'plays', 'game')]
tensor([0, 1, 2])
>>> 'hv' in sg.nodes['game'].data
False
>>> 'he' in sg.edges['plays'].data
False
>>> sg.edata['count']
{('user', 'wins', 'user'): tensor([1, 2, 1, 1])
 ('user', 'plays', 'game'): tensor([1, 1, 1])}