6.6 实现自定义GNN模块以进行小批量训练

(中文版)

注意

本教程 的内容与本节的同构图案例相似。

如果您熟悉如何为同质或异质图编写自定义GNN模块以更新整个图(参见第3章:构建GNN模块),那么在MFGs上计算的代码是类似的,不同之处在于节点被分为输入节点和输出节点。

例如,考虑以下自定义图卷积模块代码。请注意,它不一定是最有效的实现之一 - 它们仅作为自定义GNN模块可能是什么样子的示例。

class CustomGraphConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.W = nn.Linear(in_feats * 2, out_feats)

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
            return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))

如果你有一个用于全图的自定义消息传递神经网络模块,并且你想让它适用于MFGs,你只需要重写前向函数如下。请注意,来自全图实现的相应语句已被注释;你可以将原始语句与新语句进行比较。

class CustomGraphConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.W = nn.Linear(in_feats * 2, out_feats)

    # h is now a pair of feature tensors for input and output nodes, instead of
    # a single feature tensor.
    # def forward(self, g, h):
    def forward(self, block, h):
        # with g.local_scope():
        with block.local_scope():
            # g.ndata['h'] = h
            h_src = h
            h_dst = h[:block.number_of_dst_nodes()]
            block.srcdata['h'] = h_src
            block.dstdata['h'] = h_dst

            # g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
            block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))

            # return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))
            return self.W(torch.cat(
                [block.dstdata['h'], block.dstdata['h_neigh']], 1))

通常,你需要执行以下操作以使你的NN模块适用于MFGs。

Heterogeneous graphs

对于异构图,编写自定义GNN模块的方式是相似的。例如,考虑以下适用于整个图的模块。

class CustomHeteroGraphConv(nn.Module):
    def __init__(self, g, in_feats, out_feats):
        super().__init__()
        self.Ws = nn.ModuleDict()
        for etype in g.canonical_etypes:
            utype, _, vtype = etype
            self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])
        for ntype in g.ntypes:
            self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])

    def forward(self, g, h):
        with g.local_scope():
            for ntype in g.ntypes:
                g.nodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
                g.nodes[ntype].data['h_src'] = h[ntype]
            for etype in g.canonical_etypes:
                utype, _, vtype = etype
                g.update_all(
                    fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
                    etype=etype)
                g.nodes[vtype].data['h_dst'] = g.nodes[vtype].data['h_dst'] + \
                    self.Ws[etype](g.nodes[vtype].data['h_neigh'])
            return {ntype: g.nodes[ntype].data['h_dst'] for ntype in g.ntypes}

对于CustomHeteroGraphConv,原则是根据特征是用于输入还是输出,将g.nodes替换为g.srcnodesg.dstnodes

class CustomHeteroGraphConv(nn.Module):
    def __init__(self, g, in_feats, out_feats):
        super().__init__()
        self.Ws = nn.ModuleDict()
        for etype in g.canonical_etypes:
            utype, _, vtype = etype
            self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])
        for ntype in g.ntypes:
            self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])

    def forward(self, g, h):
        with g.local_scope():
            for ntype in g.ntypes:
                h_src, h_dst = h[ntype]
                g.dstnodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
                g.srcnodes[ntype].data['h_src'] = h[ntype]
            for etype in g.canonical_etypes:
                utype, _, vtype = etype
                g.update_all(
                    fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
                    etype=etype)
                g.dstnodes[vtype].data['h_dst'] = \
                    g.dstnodes[vtype].data['h_dst'] + \
                    self.Ws[etype](g.dstnodes[vtype].data['h_neigh'])
            return {ntype: g.dstnodes[ntype].data['h_dst']
                    for ntype in g.ntypes}

编写适用于同构图、二分图和MFG的模块

DGL中的所有消息传递模块都适用于同构图、单向二分图(具有两种节点类型和一种边类型)以及具有一种边类型的MFG。本质上,内置DGL神经网络模块的输入图和特征必须满足以下任一情况。

  • 如果输入特征是一对张量,那么输入图必须是单向二分图。

  • 如果输入特征是一个单一的张量,并且输入图是一个MFG,DGL将自动将输出节点上的特征设置为输入节点特征的前几行。

  • 如果输入特征必须是一个单一的张量且输入图不是MFG,那么输入图必须是同质的。

例如,以下是从PyTorch实现的dgl.nn.pytorch.SAGEConv(也可在MXNet和Tensorflow中使用)中简化的内容(去除了归一化,仅处理均值聚合等)。

import dgl.function as fn
class SAGEConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.W = nn.Linear(in_feats * 2, out_feats)

    def forward(self, g, h):
        if isinstance(h, tuple):
            h_src, h_dst = h
        elif g.is_block:
            h_src = h
            h_dst = h[:g.number_of_dst_nodes()]
        else:
            h_src = h_dst = h

        g.srcdata['h'] = h_src
        g.dstdata['h'] = h_dst
        g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_neigh'))
        return F.relu(
            self.W(torch.cat([g.dstdata['h'], g.dstdata['h_neigh']], 1)))

第3章:构建GNN模块 还提供了关于 dgl.nn.pytorch.SAGEConv 的详细说明, 它适用于单向二分图、同构图和MFGs。