dgl.DGLGraph.multi_update_all
- DGLGraph.multi_update_all(etype_dict, cross_reducer, apply_node_func=None)[source]
沿着所有边发送消息,首先按类型进行归约,然后跨不同类型进行归约,最后更新所有节点的节点特征。
- Parameters:
etype_dict (dict) –
用于边类型消息传递的参数。键是边类型,而值是消息传递参数。
允许的键格式有:
(str, str, str)
表示源节点类型、边类型和目标节点类型。或者一个
str
边类型名称,如果该名称可以唯一标识图中的三元组格式。
值必须是一个元组
(message_func, reduce_func, [apply_node_func])
,其中cross_reducer (str 或 可调用函数) – 交叉类型归约器。可以是
"sum"
,"min"
,"max"
,"mean"
,"stack"
或一个可调用函数。如果提供了一个可调用函数,输入参数必须是 一个包含来自每种边类型的聚合结果的张量列表,并且函数的输出必须是一个单一的张量。apply_node_func (callable, optional) – 一个可选的apply函数,在消息按类型和跨类型减少后执行。 它必须是一个用户定义函数。
注释
DGL 建议在类型化消息传递参数中使用 DGL 的内置函数作为 message_func 和 reduce_func,因为在这种情况下,DGL 将调用高效的内核,避免将节点特征复制到边特征。
示例
>>> import dgl >>> import dgl.function as fn >>> import torch
实例化一个异构图。
>>> g = dgl.heterograph({ ... ('user', 'follows', 'user'): ([0, 1], [1, 1]), ... ('game', 'attracts', 'user'): ([0], [1]) ... }) >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]]) >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
更新全部。
>>> g.multi_update_all( ... {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')), ... 'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))}, ... "sum") >>> g.nodes['user'].data['h'] tensor([[0.], [4.]])
用户定义的交叉缩减器,等同于“sum”。
>>> def cross_sum(flist): ... return torch.sum(torch.stack(flist, dim=0), dim=0) if len(flist) > 1 else flist[0]
使用用户定义的交叉减速器。
>>> g.multi_update_all( ... {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')), ... 'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))}, ... cross_sum)