dgl.DGLGraph.prop_nodes

DGLGraph.prop_nodes(nodes_generator, message_func, reduce_func, apply_node_func=None, etype=None)[source]

通过依次在节点上触发pull(),使用图遍历传播消息。

遍历顺序由nodes_generator指定。它生成节点前沿,即节点的列表或张量。同一前沿中的节点将一起触发,而不同前沿中的节点将根据生成顺序触发。

Parameters:
  • nodes_generator (iterable[node IDs]) – 节点边界的生成器。每个边界是一组存储在Tensor或python可迭代对象中的节点ID。 它指定了哪些节点在每一步执行pull()

  • message_func (dgl.function.BuiltinFunction可调用) – 用于沿边生成消息的消息函数。 它必须是 DGL 内置函数用户自定义函数

  • reduce_func (dgl.function.BuiltinFunction可调用) – 用于聚合消息的reduce函数。 它必须是一个DGL内置函数或一个用户自定义函数

  • apply_node_func (callable, optional) – 一个可选的apply函数,用于在消息缩减后进一步更新节点特征。它必须是一个用户自定义函数

  • etype (str(str, str, str), 可选) –

    边的类型名称。允许的类型名称格式为:

    • (str, str, str) 用于源节点类型、边类型和目标节点类型。

    • 或者一个 str 边类型名称,如果该名称可以唯一标识图中的三元组格式。

    如果图中只有一种类型的边,则可以省略。

示例

>>> import torch
>>> import dgl
>>> import dgl.function as fn

实例化一个异构图并执行多轮消息传递。

>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
>>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_u('h', 'm'),
...                         fn.sum('m', 'h'), etype='follows')
tensor([[1.],
        [2.],
        [1.],
        [2.],
        [3.]])

另请参阅

prop_edges