2.1 内置函数和消息传递API

(中文版)

在DGL中,消息函数接受一个参数edges,它是一个EdgeBatch实例。在消息传递过程中,DGL内部生成它来表示一批边。它有三个成员srcdstdata,分别用于访问源节点、目标节点和边的特征。

reduce 函数 接受一个参数 nodes,它是一个 NodeBatch 实例。在消息传递过程中, DGL 内部生成它来表示一批节点。它有成员 mailbox 用于访问该批次节点接收到的消息。 一些最常见的 reduce 操作包括 summaxmin 等。

更新函数 接受一个参数 nodes,如上所述。 此函数操作来自 reduce function 的聚合结果,通常在最后一步将其与节点的原始特征结合,并将结果保存为节点特征。

DGL 已经实现了常用的消息函数和归约函数,并将其作为内置函数放在命名空间 dgl.function 中。一般来说,DGL 建议尽可能使用内置函数,因为它们经过了高度优化,并且能够自动处理维度广播。

如果你的消息传递函数无法使用内置函数实现,你可以实现用户定义的消息/减少函数(也称为UDF)。

内置的消息函数可以是一元或二元的。DGL 支持 copy 作为一元函数。对于二元函数,DGL 支持 add, sub, mul, div, dot。消息内置函数的命名约定是 u 代表 src 节点,v 代表 dst 节点,e 代表 edges。这些函数的参数是字符串,表示相应节点和边的输入和输出字段名称。支持的内置函数列表可以在 DGL 内置函数 中找到。例如,要将 hu 特征从源节点和 hv 特征从目标节点相加,然后将结果保存在边的 he 字段中,可以使用内置函数 dgl.function.u_add_v('hu', 'hv', 'he')。这相当于消息 UDF:

def message_func(edges):
     return {'he': edges.src['hu'] + edges.dst['hv']}

内置的reduce函数支持操作summaxminmean。Reduce函数通常有两个参数,一个用于mailbox中的字段名,一个用于节点特征中的字段名,两者都是字符串。例如,dgl.function.sum('m', 'h')等同于将消息m相加的Reduce UDF:

import torch
def reduce_func(nodes):
     return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

有关UDF的高级用法,请参阅用户定义函数

也可以通过apply_edges()仅调用边缘计算而不调用消息传递。apply_edges()接受一个消息函数作为参数,并默认更新所有边的特征。例如:

import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))

对于消息传递,update_all() 是一个高级API,它将消息生成、消息聚合和节点更新合并到一个调用中,从而为整体优化留下了空间。

update_all() 的参数包括消息函数、归约函数和更新函数。可以在 update_all 外部调用更新函数,并且在调用 update_all() 时不指定它。DGL 推荐这种方法,因为更新函数通常可以写成纯张量操作,使代码简洁。例如:

def update_all_example(graph):
    # store the result in graph.ndata['ft']
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # Call update function outside of update_all
    final_ft = graph.ndata['ft'] * 2
    return final_ft

此调用将通过乘以源节点特征 ft 和边特征 a 生成消息 m,汇总消息 m 以更新节点特征 ft,最后将 ft 乘以 2 以获得结果 final_ft。调用后,DGL 将清除中间消息 m。上述函数的数学公式为:

\[{final\_ft}_i = 2 * \sum_{j\in\mathcal{N}(i)} ({ft}_j * a_{ji})\]

DGL的内置函数支持浮点数据类型,即特征必须是half (float16) /float/double张量。 float16数据类型默认是禁用的,因为它对GPU的计算能力有最低要求,即sm_53(Pascal、Volta、Turing和Ampere架构)。

用户可以通过从源代码编译DGL来启用float16进行混合精度训练 (详情请参见混合精度训练教程)。