dgl.DGLGraph.update_all

DGLGraph.update_all(message_func, reduce_func, apply_node_func=None, etype=None)[source]

沿着指定类型的所有边发送消息,并更新相应目标类型的所有节点。

对于关系类型数量大于1的异构图,沿着所有边发送消息,按类型和跨类型同时减少它们。然后,更新所有节点的节点特征。

Parameters:
  • message_func (dgl.function.BuiltinFunction or callable) – The message function to generate messages along the edges. It must be either a DGL Built-in Function or a User-defined Functions.

  • reduce_func (dgl.function.BuiltinFunction or callable) – The reduce function to aggregate the messages. It must be either a DGL Built-in Function or a User-defined Functions.

  • apply_node_func (callable, optional) – An optional apply function to further update the node features after the message reduction. It must be a User-defined Functions.

  • etype (str or (str, str, str), optional) –

    The type name of the edges. The allowed type name formats are:

    • (str, str, str) for source node type, edge type and destination node type.

    • or one str edge type name if the name can uniquely identify a triplet format in the graph.

    Can be omitted if the graph has only one type of edges.

注释

  • 如果图中的某些节点没有入边,DGL不会为这些节点调用消息和减少函数,并用零填充它们的聚合消息。用户可以通过set_n_initializer()控制填充的值。如果提供了apply_node_func,DGL仍然会调用它。

  • DGL recommends using DGL’s bulit-in function for the message_func and the reduce_func arguments, because DGL will invoke efficient kernels that avoids copying node features to edge features in this case.

示例

>>> import dgl
>>> import dgl.function as fn
>>> import torch

同构图

>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
>>> g.ndata['x'] = torch.ones(5, 2)
>>> g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))
>>> g.ndata['h']
tensor([[0., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])

异构图

>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2], [1, 2, 2])})

更新全部。

>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g['follows'].update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
tensor([[0.],
        [0.],
        [3.]])

异质图(关系类型数量 > 1)

>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): ([0, 1], [1, 1]),
...     ('game', 'attracts', 'user'): ([0], [1])
... })

更新全部。

>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> g.nodes['user'].data['h']
tensor([[0.],
        [4.]])