2.2 编写高效的消息传递代码

(中文版)

DGL优化了消息传递的内存消耗和计算速度。利用这些优化的常见做法是将自己的消息传递功能构建为update_all()调用与内置函数作为参数的组合。

除此之外,考虑到某些图中边的数量远大于节点的数量,避免从节点到边的不必要内存复制是有益的。对于某些情况,如 GATConv, 其中需要在边上保存消息,需要调用 apply_edges() 并使用内置函数。有时边上的消息可能是高维的,这会消耗大量内存。 DGL 建议尽可能保持边特征的维度较低。

这里有一个示例,展示了如何通过在节点上拆分操作来实现这一点。该方法执行以下操作:连接src特征和dst特征,然后应用一个线性层,即\(W\times (u || v)\)srcdst特征的维度较高,而线性层输出的维度较低。一个直接的实现可能如下:

import torch
import torch.nn as nn

linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim)))
def concat_message_function(edges):
     return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=1)}
g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] @ linear

建议的实现将线性操作分为两部分, 一部分应用于src特征,另一部分应用于dst特征。 然后在最后阶段将线性操作的输出添加到边上, 即执行\(W_l\times u + W_r \times v\)。这是因为 \(W \times (u||v) = W_l \times u + W_r \times v\),其中\(W_l\)\(W_r\)分别是矩阵\(W\)的左半部分和右半部分:

import dgl.function as fn

linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
out_src = g.ndata['feat'] @ linear_src
out_dst = g.ndata['feat'] @ linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))

上述两种实现在数学上是等价的。后一种实现更为高效,因为它不需要在边上保存feat_src和feat_dst,这样更节省内存。此外,加法可以通过DGL的内置函数u_add_v()进行优化,这进一步加快了计算速度并节省了内存占用。