Write your own GNN module

Sometimes your model goes beyond simply stacking existing GNN modules. For example, you may want to invent a new way of aggregating neighbor information by considering node importance or edge weights.

By the end of this tutorial you will be able to

  • Understand DGL's message passing APIs.

  • Implement the GraphSAGE convolution module by yourself.

This tutorial assumes that you already know the basics of training a GNN for node classification.

(Time estimate: 10 minutes)

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

Message passing and GNNs

DGL follows the message passing paradigm inspired by the Message Passing Neural Network proposed by Gilmer et al. Essentially, they found that many GNN models can fit into the following framework:

\[m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)\]
\[m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}\]
\[h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)\]

where DGL calls \(M^{(l)}\) the message function, \(\sum\) the reduce function, and \(U^{(l)}\) the update function. Note that \(\sum\) here can represent any function and is not necessarily a summation.
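
Before turning to the DGL APIs, it may help to see the framework spelled out in plain Python. The following is a minimal sketch of a single message-passing layer on a made-up toy graph; the graph, feature values, and the particular message and update functions are chosen purely for illustration.

edges = [(0, 1), (2, 1), (1, 2)]           # (u, v) pairs: u sends a message to v
h = {0: 1.0, 1: 2.0, 2: 4.0}               # previous-layer features h^{(l-1)}

def message(h_v, h_u):                     # M^{(l)}: here simply the sender's feature
    return h_u

def update(h_v, m_v):                      # U^{(l)}: here add and apply ReLU
    return max(0.0, h_v + m_v)

m = {v: 0.0 for v in h}
for u, v in edges:                         # summation plays the role of the reduce function
    m[v] += message(h[v], h[u])

h_new = {v: update(h[v], m[v]) for v in h}
print(h_new)                               # {0: 1.0, 1: 7.0, 2: 6.0}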

For example, the GraphSAGE convolution (Hamilton et al., 2017) takes the following mathematical form:

\[h_{\mathcal{N}(v)}^k\leftarrow \text{Average}\{h_u^{k-1},\forall u\in\mathcal{N}(v)\}\]
\[h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right)\]

You can see that message passing is directional: the message sent from a node \(u\) to a node \(v\) is not necessarily the same as the message sent in the opposite direction from node \(v\) to node \(u\).
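
As a consequence, if you want information to flow in both directions on a directed graph, one common option (shown here as a sketch on a toy graph, not something this tutorial requires) is to add reverse edges before message passing:

g_directed = dgl.graph(([0, 1], [1, 2]))       # toy graph with edges 0->1 and 1->2
g_bidirected = dgl.add_reverse_edges(g_directed)
print(g_bidirected.edges())                    # now also contains 1->0 and 2->1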

Although DGL has built-in support for GraphSAGE via dgl.nn.SAGEConv, here is how you can implement GraphSAGE convolution in DGL by yourself.

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        with g.local_scope():
            g.ndata["h"] = h
            # update_all is a message passing API.
            g.update_all(
                message_func=fn.copy_u("h", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

The central piece in this code is the g.update_all function, which gathers and averages the neighbor features. There are three concepts here (a toy example follows the list):

  • Message function fn.copy_u('h', 'm') copies the node feature under name 'h' as messages with name 'm' sent to neighbors.

  • Reduce function fn.mean('m', 'h_N') averages all the received messages under name 'm' and saves the result as a new node feature 'h_N'.

  • update_all tells DGL to trigger the message and reduce functions for all the nodes and edges.
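
Here is how these three pieces fit together on a tiny made-up graph, outside of any module (the graph and feature values are chosen only for illustration):

toy_g = dgl.graph(([0, 1, 2], [1, 2, 1]))              # edges 0->1, 1->2, 2->1
toy_g.ndata["h"] = torch.tensor([[1.0], [2.0], [4.0]])
toy_g.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_N"))
print(toy_g.ndata["h_N"])
# Node 1 averages the messages from nodes 0 and 2: (1 + 4) / 2 = 2.5.
# Node 0 has no incoming edges, so its h_N is filled with zeros.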

Afterwards, you can stack your own GraphSAGE convolution layers to form a multi-layer GraphSAGE network.

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

Training loop

The following code for data loading and the training loop is copied directly from the introduction tutorial.

import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]


def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(200):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that we should only compute the losses of the nodes in the training set,
        # i.e. with train_mask 1.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        if e % 5 == 0:
            print(
                "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                    e, loss, val_acc, best_val_acc, test_acc, best_test_acc
                )
            )


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.953, val acc: 0.122 (best 0.122), test acc: 0.130 (best 0.130)
In epoch 5, loss: 1.871, val acc: 0.416 (best 0.424), test acc: 0.405 (best 0.414)
In epoch 10, loss: 1.712, val acc: 0.466 (best 0.466), test acc: 0.453 (best 0.453)
In epoch 15, loss: 1.473, val acc: 0.636 (best 0.636), test acc: 0.637 (best 0.637)
In epoch 20, loss: 1.167, val acc: 0.688 (best 0.688), test acc: 0.683 (best 0.683)
In epoch 25, loss: 0.839, val acc: 0.736 (best 0.736), test acc: 0.738 (best 0.738)
In epoch 30, loss: 0.546, val acc: 0.770 (best 0.770), test acc: 0.764 (best 0.764)
In epoch 35, loss: 0.326, val acc: 0.768 (best 0.772), test acc: 0.776 (best 0.771)
In epoch 40, loss: 0.186, val acc: 0.770 (best 0.772), test acc: 0.781 (best 0.771)
In epoch 45, loss: 0.107, val acc: 0.768 (best 0.772), test acc: 0.790 (best 0.771)
In epoch 50, loss: 0.064, val acc: 0.770 (best 0.772), test acc: 0.790 (best 0.771)
In epoch 55, loss: 0.041, val acc: 0.772 (best 0.772), test acc: 0.791 (best 0.771)
In epoch 60, loss: 0.029, val acc: 0.772 (best 0.772), test acc: 0.790 (best 0.771)
In epoch 65, loss: 0.021, val acc: 0.776 (best 0.776), test acc: 0.786 (best 0.790)
In epoch 70, loss: 0.017, val acc: 0.772 (best 0.776), test acc: 0.782 (best 0.790)
In epoch 75, loss: 0.014, val acc: 0.766 (best 0.776), test acc: 0.780 (best 0.790)
In epoch 80, loss: 0.012, val acc: 0.768 (best 0.776), test acc: 0.778 (best 0.790)
In epoch 85, loss: 0.010, val acc: 0.768 (best 0.776), test acc: 0.778 (best 0.790)
In epoch 90, loss: 0.009, val acc: 0.766 (best 0.776), test acc: 0.775 (best 0.790)
In epoch 95, loss: 0.008, val acc: 0.768 (best 0.776), test acc: 0.774 (best 0.790)
In epoch 100, loss: 0.007, val acc: 0.770 (best 0.776), test acc: 0.774 (best 0.790)
In epoch 105, loss: 0.007, val acc: 0.766 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 110, loss: 0.006, val acc: 0.766 (best 0.776), test acc: 0.774 (best 0.790)
In epoch 115, loss: 0.006, val acc: 0.766 (best 0.776), test acc: 0.774 (best 0.790)
In epoch 120, loss: 0.006, val acc: 0.766 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 125, loss: 0.005, val acc: 0.768 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 130, loss: 0.005, val acc: 0.768 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 135, loss: 0.005, val acc: 0.768 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 140, loss: 0.004, val acc: 0.768 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 145, loss: 0.004, val acc: 0.768 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 150, loss: 0.004, val acc: 0.766 (best 0.776), test acc: 0.774 (best 0.790)
In epoch 155, loss: 0.004, val acc: 0.764 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 160, loss: 0.004, val acc: 0.764 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 165, loss: 0.003, val acc: 0.764 (best 0.776), test acc: 0.774 (best 0.790)
In epoch 170, loss: 0.003, val acc: 0.766 (best 0.776), test acc: 0.774 (best 0.790)
In epoch 175, loss: 0.003, val acc: 0.766 (best 0.776), test acc: 0.774 (best 0.790)
In epoch 180, loss: 0.003, val acc: 0.764 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 185, loss: 0.003, val acc: 0.764 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 190, loss: 0.003, val acc: 0.764 (best 0.776), test acc: 0.773 (best 0.790)
In epoch 195, loss: 0.003, val acc: 0.762 (best 0.776), test acc: 0.773 (best 0.790)

More customization

In DGL, we provide many built-in message and reduce functions under the dgl.function package. You can find more details in the API documentation.
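
For instance, here are a couple of other built-in combinations (a sketch only; the output feature names h_sum and h_max are arbitrary):

with g.local_scope():
    g.ndata["h"] = g.ndata["feat"]
    # Sum the neighbor features instead of averaging them.
    g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h_sum"))
    # Form messages by adding source and destination features, then take an element-wise max.
    g.update_all(fn.u_add_v("h", "h", "m"), fn.max("m", "h_max"))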

These APIs allow one to quickly implement new graph convolution modules. For example, the following implements a new SAGEConv that aggregates neighbor representations using a weighted average. Note that edata can hold edge features, which can also take part in message passing.

class WeightedSAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model with edge weights.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        with g.local_scope():
            g.ndata["h"] = h
            g.edata["w"] = w
            g.update_all(
                message_func=fn.u_mul_e("h", "w", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

Because the graph in this dataset does not have edge weights, we manually assign all edge weights to one in the forward() function of the model. You can replace them with your own edge weights.

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)
In epoch 0, loss: 1.952, val acc: 0.156 (best 0.156), test acc: 0.144 (best 0.144)
In epoch 5, loss: 1.869, val acc: 0.182 (best 0.182), test acc: 0.175 (best 0.175)
In epoch 10, loss: 1.705, val acc: 0.436 (best 0.436), test acc: 0.437 (best 0.437)
In epoch 15, loss: 1.461, val acc: 0.590 (best 0.590), test acc: 0.553 (best 0.553)
In epoch 20, loss: 1.152, val acc: 0.634 (best 0.634), test acc: 0.598 (best 0.598)
In epoch 25, loss: 0.819, val acc: 0.684 (best 0.684), test acc: 0.644 (best 0.644)
In epoch 30, loss: 0.520, val acc: 0.714 (best 0.714), test acc: 0.691 (best 0.691)
In epoch 35, loss: 0.301, val acc: 0.748 (best 0.748), test acc: 0.734 (best 0.734)
In epoch 40, loss: 0.168, val acc: 0.764 (best 0.764), test acc: 0.743 (best 0.743)
In epoch 45, loss: 0.095, val acc: 0.764 (best 0.764), test acc: 0.749 (best 0.743)
In epoch 50, loss: 0.057, val acc: 0.766 (best 0.768), test acc: 0.751 (best 0.750)
In epoch 55, loss: 0.036, val acc: 0.764 (best 0.768), test acc: 0.749 (best 0.750)
In epoch 60, loss: 0.025, val acc: 0.758 (best 0.768), test acc: 0.749 (best 0.750)
In epoch 65, loss: 0.019, val acc: 0.762 (best 0.768), test acc: 0.752 (best 0.750)
In epoch 70, loss: 0.015, val acc: 0.760 (best 0.768), test acc: 0.752 (best 0.750)
In epoch 75, loss: 0.012, val acc: 0.760 (best 0.768), test acc: 0.754 (best 0.750)
In epoch 80, loss: 0.011, val acc: 0.760 (best 0.768), test acc: 0.755 (best 0.750)
In epoch 85, loss: 0.009, val acc: 0.758 (best 0.768), test acc: 0.756 (best 0.750)
In epoch 90, loss: 0.008, val acc: 0.756 (best 0.768), test acc: 0.754 (best 0.750)
In epoch 95, loss: 0.007, val acc: 0.758 (best 0.768), test acc: 0.753 (best 0.750)
In epoch 100, loss: 0.007, val acc: 0.758 (best 0.768), test acc: 0.752 (best 0.750)
In epoch 105, loss: 0.006, val acc: 0.758 (best 0.768), test acc: 0.752 (best 0.750)
In epoch 110, loss: 0.006, val acc: 0.758 (best 0.768), test acc: 0.752 (best 0.750)
In epoch 115, loss: 0.005, val acc: 0.758 (best 0.768), test acc: 0.754 (best 0.750)
In epoch 120, loss: 0.005, val acc: 0.756 (best 0.768), test acc: 0.755 (best 0.750)
In epoch 125, loss: 0.005, val acc: 0.754 (best 0.768), test acc: 0.755 (best 0.750)
In epoch 130, loss: 0.005, val acc: 0.754 (best 0.768), test acc: 0.755 (best 0.750)
In epoch 135, loss: 0.004, val acc: 0.754 (best 0.768), test acc: 0.755 (best 0.750)
In epoch 140, loss: 0.004, val acc: 0.754 (best 0.768), test acc: 0.756 (best 0.750)
In epoch 145, loss: 0.004, val acc: 0.756 (best 0.768), test acc: 0.756 (best 0.750)
In epoch 150, loss: 0.004, val acc: 0.756 (best 0.768), test acc: 0.756 (best 0.750)
In epoch 155, loss: 0.003, val acc: 0.758 (best 0.768), test acc: 0.757 (best 0.750)
In epoch 160, loss: 0.003, val acc: 0.758 (best 0.768), test acc: 0.756 (best 0.750)
In epoch 165, loss: 0.003, val acc: 0.756 (best 0.768), test acc: 0.757 (best 0.750)
In epoch 170, loss: 0.003, val acc: 0.758 (best 0.768), test acc: 0.756 (best 0.750)
In epoch 175, loss: 0.003, val acc: 0.758 (best 0.768), test acc: 0.757 (best 0.750)
In epoch 180, loss: 0.003, val acc: 0.758 (best 0.768), test acc: 0.756 (best 0.750)
In epoch 185, loss: 0.003, val acc: 0.758 (best 0.768), test acc: 0.756 (best 0.750)
In epoch 190, loss: 0.002, val acc: 0.758 (best 0.768), test acc: 0.756 (best 0.750)
In epoch 195, loss: 0.002, val acc: 0.758 (best 0.768), test acc: 0.755 (best 0.750)

Even more customization with user-defined functions

DGL allows user-defined message and reduce functions for maximal expressiveness. Here is a user-defined message function that is equivalent to fn.u_mul_e('h', 'w', 'm').

def u_mul_e_udf(edges):
    return {"m": edges.src["h"] * edges.data["w"]}

edges has three members: src, data, and dst, representing the source node features, edge features, and destination node features of all edges, respectively.
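
For example, here is another hypothetical message function that also uses the destination node feature through edges.dst; it is equivalent to the built-in fn.u_add_v('h', 'h', 'm'):

def u_add_v_udf(edges):
    # Add the source and destination node features to form the message.
    return {"m": edges.src["h"] + edges.dst["h"]}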

You can also write your own reduce function. For example, the following is equivalent to the built-in fn.mean('m', 'h_N') function that averages the incoming messages:

def mean_udf(nodes):
    return {"h_N": nodes.mailbox["m"].mean(1)}

In short, DGL groups the nodes by their in-degrees, and for each group stacks the incoming messages along the second dimension. You can then perform a reduction along the second dimension to aggregate the messages.
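
As another sketch, a reduce function that takes an element-wise maximum over the incoming messages, equivalent to the built-in fn.max('m', 'h_N'), could look like this:

def max_udf(nodes):
    # nodes.mailbox['m'] has shape (num_nodes_in_group, num_incoming_messages, feat_dim).
    return {"h_N": nodes.mailbox["m"].max(1)[0]}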

For more details on customizing message and reduce functions with user-defined functions, please refer to the API reference.

Best practice of writing custom GNN modules

DGL recommends the following practice, ranked by preference:

  • Use dgl.nn modules.

  • Use dgl.nn.functional functions, which contain lower-level complex operations such as computing a softmax over the incoming edges for each node (a short sketch of the first two options follows this list).

  • Use update_all with built-in message and reduce functions.

  • Use user-defined message or reduce functions.
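
Below is a minimal sketch of the first two recommendations, assuming the Cora graph g and its features from earlier in this tutorial; the edge scores are random and only serve as an illustration.

import dgl.nn as dglnn
from dgl.nn.functional import edge_softmax

# Preference 1: use a built-in module instead of a hand-written one.
builtin_conv = dglnn.SAGEConv(g.ndata["feat"].shape[1], 16, aggregator_type="mean")
h_out = builtin_conv(g, g.ndata["feat"])

# Preference 2: use dgl.nn.functional helpers for lower-level operations,
# e.g. a per-node softmax over the scores of all incoming edges.
edge_scores = torch.rand(g.num_edges(), 1)
attention = edge_softmax(g, edge_scores)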

What's next?


Total running time of the script: (0 minutes 33.472 seconds)
