dgl.DGLGraph.multi_update_all

DGLGraph.multi_update_all(etype_dict, cross_reducer, apply_node_func=None)[source]

沿着所有边发送消息,首先按类型进行归约,然后跨不同类型进行归约,最后更新所有节点的节点特征。

Parameters:
  • etype_dict (dict) –

    用于边类型消息传递的参数。键是边类型,而值是消息传递参数。

    允许的键格式有:

    • (str, str, str) 表示源节点类型、边类型和目标节点类型。

    • 或者一个 str 边类型名称,如果该名称可以唯一标识图中的三元组格式。

    值必须是一个元组 (message_func, reduce_func, [apply_node_func]),其中

    • message_funcdgl.function.BuiltinFunction 或 callable

      用于沿边生成消息的消息函数。它必须是 DGL 内置函数用户定义函数

    • reduce_funcdgl.function.BuiltinFunction 或 callable

      用于聚合消息的归约函数。它必须是 DGL 内置函数用户定义函数

    • apply_node_funccallable, 可选

      一个可选的节点应用函数,用于在消息归约后进一步更新节点特征。它必须是 用户定义函数

  • cross_reducer (str可调用函数) – 交叉类型归约器。可以是 "sum", "min", "max", "mean", "stack" 或一个可调用函数。如果提供了一个可调用函数,输入参数必须是 一个包含来自每种边类型的聚合结果的张量列表,并且函数的输出必须是一个单一的张量。

  • apply_node_func (callable, optional) – 一个可选的apply函数,在消息按类型和跨类型减少后执行。 它必须是一个用户定义函数

注释

DGL 建议在类型化消息传递参数中使用 DGL 的内置函数作为 message_func 和 reduce_func,因为在这种情况下,DGL 将调用高效的内核,避免将节点特征复制到边特征。

示例

>>> import dgl
>>> import dgl.function as fn
>>> import torch

实例化一个异构图。

>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): ([0, 1], [1, 1]),
...     ('game', 'attracts', 'user'): ([0], [1])
... })
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])

更新全部。

>>> g.multi_update_all(
...     {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
...      'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
... "sum")
>>> g.nodes['user'].data['h']
tensor([[0.],
        [4.]])

用户定义的交叉缩减器,等同于“sum”。

>>> def cross_sum(flist):
...     return torch.sum(torch.stack(flist, dim=0), dim=0) if len(flist) > 1 else flist[0]

使用用户定义的交叉减速器。

>>> g.multi_update_all(
...     {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
...      'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
... cross_sum)