2.2 编写高效的消息传递代码
DGL优化了消息传递的内存消耗和计算速度。利用这些优化的常见做法是将自己的消息传递功能构建为update_all()
调用与内置函数作为参数的组合。
除此之外,考虑到某些图中边的数量远大于节点的数量,避免从节点到边的不必要内存复制是有益的。对于某些情况,如
GATConv
,
其中需要在边上保存消息,需要调用
apply_edges()
并使用内置函数。有时边上的消息可能是高维的,这会消耗大量内存。
DGL 建议尽可能保持边特征的维度较低。
这里有一个示例,展示了如何通过在节点上拆分操作来实现这一点。该方法执行以下操作:连接src
特征和dst
特征,然后应用一个线性层,即\(W\times (u || v)\)。src
和dst
特征的维度较高,而线性层输出的维度较低。一个直接的实现可能如下:
import torch
import torch.nn as nn
linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim)))
def concat_message_function(edges):
return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=1)}
g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] @ linear
建议的实现将线性操作分为两部分,
一部分应用于src
特征,另一部分应用于dst
特征。
然后在最后阶段将线性操作的输出添加到边上,
即执行\(W_l\times u + W_r \times v\)。这是因为
\(W \times (u||v) = W_l \times u + W_r \times v\),其中\(W_l\)
和\(W_r\)分别是矩阵\(W\)的左半部分和右半部分:
import dgl.function as fn
linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
out_src = g.ndata['feat'] @ linear_src
out_dst = g.ndata['feat'] @ linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))
上述两种实现在数学上是等价的。后一种实现更为高效,因为它不需要在边上保存feat_src和feat_dst,这样更节省内存。此外,加法可以通过DGL的内置函数u_add_v()
进行优化,这进一步加快了计算速度并节省了内存占用。