3.3 异构图卷积模块

(中文版)

HeteroGraphConv 是一个模块级封装,用于在异质图上运行DGL NN模块。其实现逻辑与消息传递级别的API multi_update_all()相同,包括:

  • 每个关系\(r\)内的DGL NN模块。

  • 减少操作,将来自多个关系的相同节点类型的结果合并。

这可以表述为:

\[h_{dst}^{(l+1)} = \underset{r\in\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))\]

其中 \(f_r\) 是每个关系 \(r\) 的神经网络模块, \(AGG\) 是聚合函数。

HeteroGraphConv 实现逻辑:

import torch.nn as nn

class HeteroGraphConv(nn.Module):
    def __init__(self, mods, aggregate='sum'):
        super(HeteroGraphConv, self).__init__()
        self.mods = nn.ModuleDict(mods)
        if isinstance(aggregate, str):
            # An internal function to get common aggregation functions
            self.agg_fn = get_aggregate_fn(aggregate)
        else:
            self.agg_fn = aggregate

异构图卷积接受一个字典 mods,该字典将每个关系映射到一个 nn 模块,并设置从多个关系在同一节点类型上聚合结果的函数。

def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
    if mod_args is None:
        mod_args = {}
    if mod_kwargs is None:
        mod_kwargs = {}
    outputs = {nty : [] for nty in g.dsttypes}

除了输入图和输入张量,forward() 函数还接受两个额外的字典参数 mod_argsmod_kwargs。这两个字典的键与 self.mods 相同。它们在调用 self.mods 中对应的神经网络模块时,作为不同类型的自定义参数使用。

创建一个输出字典来保存每个目标类型 nty 的输出张量。请注意,每个 nty 的值是一个列表,表示如果一个节点类型有多个关系将 nty 作为目标类型,则可能会获得多个输出。HeteroGraphConv 将对这些列表进行进一步的聚合。

if g.is_block:
    src_inputs = inputs
    dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:
    src_inputs = dst_inputs = inputs

for stype, etype, dtype in g.canonical_etypes:
    rel_graph = g[stype, etype, dtype]
    if rel_graph.num_edges() == 0:
        continue
    if stype not in src_inputs or dtype not in dst_inputs:
        continue
    dstdata = self.mods[etype](
        rel_graph,
        (src_inputs[stype], dst_inputs[dtype]),
        *mod_args.get(etype, ()),
        **mod_kwargs.get(etype, {}))
    outputs[dtype].append(dstdata)

输入 g 可以是一个异构图或来自异构图的一个子图块。与普通的神经网络模块一样,forward() 函数需要分别处理不同的输入图类型。

每个关系都表示为canonical_etype,即 (stype, etype, dtype)。使用canonical_etype作为键,可以 提取出一个二分图rel_graph。对于二分图,输入特征将被组织为一个元组 (src_inputs[stype], dst_inputs[dtype])。每个关系的NN模块被调用并保存输出。为了避免不必要的调用, 没有边或没有源类型节点的关系将被跳过。

rsts = {}
for nty, alist in outputs.items():
    if len(alist) != 0:
        rsts[nty] = self.agg_fn(alist, nty)

最后,来自多个关系的相同目标节点类型的结果使用self.agg_fn函数进行聚合。可以在HeteroGraphConv的API文档中找到示例。