3.2 DGL NN模块前向函数

(中文版)

在NN模块中,forward() 函数负责实际的消息传递和计算。与通常以张量作为参数的PyTorch NN模块相比,DGL NN模块还额外接受一个参数 dgl.DGLGraphforward() 函数的工作可以分为三个部分:

  • 图形检查和图形类型规范。

  • 消息传递。

  • 功能更新。

本节其余部分将深入探讨SAGEConv示例中的forward()函数。

图形检查和图形类型规范

def forward(self, graph, feat):
    with graph.local_scope():
        # Specify graph type then expand input feature according to graph type
        feat_src, feat_dst = expand_as_pair(feat, graph)

forward() 需要处理输入中的许多边缘情况,这些情况可能导致计算和消息传递中的无效值。在像 GraphConv 这样的卷积模块中,一个典型的检查是验证输入图没有0入度的节点。当一个节点的入度为0时,mailbox 将为空,并且归约函数将产生全零值。这可能会导致模型性能的无声回归。然而,在 SAGEConv 模块中,聚合表示将与原始节点特征连接,forward() 的输出不会全为零。在这种情况下不需要进行此类检查。

DGL NN 模块应该可以在不同类型的图输入中重复使用,包括:同构图、异构图(1.5 异构图)、子图块(第6章:大型图上的随机训练)。

SAGEConv的数学公式是:

\[h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate} \left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right)\]
\[h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat} (h_{dst}^{l}, h_{\mathcal{N}(dst)}^{l+1}) + b \right)\]
\[h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{l+1})\]

需要根据图类型指定源节点特征 feat_src 和目标节点特征 feat_dstexpand_as_pair() 是一个函数,用于指定图类型并将 feat 扩展为 feat_srcfeat_dst。 该函数的详细信息如下所示。

def expand_as_pair(input_, g=None):
    if isinstance(input_, tuple):
        # Bipartite graph case
        return input_
    elif g is not None and g.is_block:
        # Subgraph block case
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()}
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        # Homogeneous graph case
        return input_, input_

对于同质全图训练,源节点和目标节点是相同的。它们都是图中的所有节点。

对于异构情况,图可以分割成几个二分图,每个关系一个。关系表示为 (src_type, edge_type, dst_dtype)。当它识别到输入特征 feat 是一个元组时,它将把图视为二分图。元组中的第一个 元素将是源节点特征,第二个元素将是目标节点特征。

在小批量训练中,计算是基于一组目标节点采样的子图进行的。在DGL中,这个子图被称为block。在块创建阶段,dst nodes位于节点列表的前面。可以通过索引[0:g.number_of_dst_nodes()]找到feat_dst

在确定 feat_srcfeat_dst 之后,上述三种图类型的计算是相同的。

消息传递和减少

import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape

if self._aggre_type == 'mean':
    graph.srcdata['h'] = feat_src
    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
    check_eq_shape(feat)
    graph.srcdata['h'] = feat_src
    graph.dstdata['h'] = feat_dst
    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
    # divide in_degrees
    degs = graph.in_degrees().to(feat_dst)
    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
    graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
else:
    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
    rst = self.fc_neigh(h_neigh)
else:
    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

代码实际上执行了消息传递和减少计算。这部分代码因模块而异。请注意,上述代码中的所有消息传递都是使用update_all() API和内置的消息/减少函数来实现的,以充分利用DGL的性能优化,如2.2 编写高效的消息传递代码中所述。

减少输出后更新功能

# activation
if self.activation is not None:
    rst = self.activation(rst)
# normalization
if self.norm is not None:
    rst = self.norm(rst)
return rst

forward() 函数的最后一部分是在 reduce function 之后更新特征。常见的更新操作包括根据对象构建阶段设置的选项应用激活函数和归一化。