2.1 内置函数和消息传递API
在DGL中,消息函数接受一个参数edges
,它是一个EdgeBatch
实例。在消息传递过程中,DGL内部生成它来表示一批边。它有三个成员src
、dst
和data
,分别用于访问源节点、目标节点和边的特征。
reduce 函数 接受一个参数 nodes
,它是一个
NodeBatch
实例。在消息传递过程中,
DGL 内部生成它来表示一批节点。它有成员
mailbox
用于访问该批次节点接收到的消息。
一些最常见的 reduce 操作包括 sum
、max
、min
等。
更新函数 接受一个参数 nodes
,如上所述。
此函数操作来自 reduce function
的聚合结果,通常在最后一步将其与节点的原始特征结合,并将结果保存为节点特征。
DGL 已经实现了常用的消息函数和归约函数,并将其作为内置函数放在命名空间 dgl.function
中。一般来说,DGL 建议尽可能使用内置函数,因为它们经过了高度优化,并且能够自动处理维度广播。
如果你的消息传递函数无法使用内置函数实现,你可以实现用户定义的消息/减少函数(也称为UDF)。
内置的消息函数可以是一元或二元的。DGL 支持 copy
作为一元函数。对于二元函数,DGL 支持 add
, sub
, mul
, div
, dot
。消息内置函数的命名约定是 u
代表 src
节点,v
代表 dst
节点,e
代表 edges
。这些函数的参数是字符串,表示相应节点和边的输入和输出字段名称。支持的内置函数列表可以在 DGL 内置函数 中找到。例如,要将 hu
特征从源节点和 hv
特征从目标节点相加,然后将结果保存在边的 he
字段中,可以使用内置函数 dgl.function.u_add_v('hu', 'hv', 'he')
。这相当于消息 UDF:
def message_func(edges):
return {'he': edges.src['hu'] + edges.dst['hv']}
内置的reduce函数支持操作sum
、max
、min
和mean
。Reduce函数通常有两个参数,一个用于mailbox
中的字段名,一个用于节点特征中的字段名,两者都是字符串。例如,dgl.function.sum('m', 'h')
等同于将消息m
相加的Reduce UDF:
import torch
def reduce_func(nodes):
return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
有关UDF的高级用法,请参阅用户定义函数。
也可以通过apply_edges()
仅调用边缘计算而不调用消息传递。apply_edges()
接受一个消息函数作为参数,并默认更新所有边的特征。例如:
import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
对于消息传递,update_all()
是一个高级API,它将消息生成、消息聚合和节点更新合并到一个调用中,从而为整体优化留下了空间。
update_all()
的参数包括消息函数、归约函数和更新函数。可以在 update_all
外部调用更新函数,并且在调用 update_all()
时不指定它。DGL 推荐这种方法,因为更新函数通常可以写成纯张量操作,使代码简洁。例如:
def update_all_example(graph):
# store the result in graph.ndata['ft']
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
# Call update function outside of update_all
final_ft = graph.ndata['ft'] * 2
return final_ft
此调用将通过乘以源节点特征 ft
和边特征 a
生成消息 m
,汇总消息 m
以更新节点特征 ft
,最后将 ft
乘以 2 以获得结果 final_ft
。调用后,DGL 将清除中间消息 m
。上述函数的数学公式为:
DGL的内置函数支持浮点数据类型,即特征必须是half
(float16
) /float
/double
张量。
float16
数据类型默认是禁用的,因为它对GPU的计算能力有最低要求,即sm_53
(Pascal、Volta、Turing和Ampere架构)。
用户可以通过从源代码编译DGL来启用float16进行混合精度训练 (详情请参见混合精度训练教程)。