注意
Go to the end to download the full example code
使用图神经网络进行链接预测
在介绍中,你已经学习了使用GNN进行节点分类的基本工作流程,即预测图中节点的类别。本教程将教你如何训练GNN进行链接预测,即预测图中两个任意节点之间是否存在边。
在本教程结束时,您将能够
构建一个基于GNN的链接预测模型。
在DGL提供的小型数据集上训练和评估模型。
(预计时间:28分钟)
import itertools
import os
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
使用GNN进行链接预测的概述
许多应用程序,如社交推荐、物品推荐、知识图谱补全等,都可以被表述为链接预测,即预测两个特定节点之间是否存在边。本教程展示了一个预测引用网络中两篇论文之间是否存在引用关系(无论是引用还是被引用)的示例。
本教程将链接预测问题表述为一个二元分类问题,如下所示:
将图中的边视为正例。
采样一些不存在的边(即没有边连接的节点对)作为负样本。
将正例和负例划分为训练集和测试集。
使用任何二元分类指标(如曲线下面积(AUC))评估模型。
注意
该实践源自 SEAL, 尽管这里的模型没有使用他们的节点标记思想。
在某些领域,如大规模推荐系统或信息检索,您可能更倾向于使用强调前K个预测性能良好的指标。在这些情况下,您可能需要考虑其他指标,如平均精度均值,并使用其他负采样方法,这些内容超出了本教程的范围。
加载图和特征
在介绍之后,本教程首先加载Cora数据集。
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
准备训练和测试集
本教程随机选取测试集中10%的边作为正例,其余的边留作训练集。然后,它在两个集合中为负例采样相同数量的边。
# Split edge set for training and testing
u, v = g.edges()
eids = np.arange(g.num_edges())
eids = np.random.permutation(eids)
test_size = int(len(eids) * 0.1)
train_size = g.num_edges() - test_size
test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]
# Find all negative edges and split them for training and testing
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
adj_neg = 1 - adj.todense() - np.eye(g.num_nodes())
neg_u, neg_v = np.where(adj_neg != 0)
neg_eids = np.random.choice(len(neg_u), g.num_edges())
test_neg_u, test_neg_v = (
neg_u[neg_eids[:test_size]],
neg_v[neg_eids[:test_size]],
)
train_neg_u, train_neg_v = (
neg_u[neg_eids[test_size:]],
neg_v[neg_eids[test_size:]],
)
在训练时,你需要从原始图中移除测试集中的边。你可以通过dgl.remove_edges
来实现这一点。
注意
dgl.remove_edges
通过从原始图创建子图来工作,这会导致复制,因此对于大图可能会很慢。如果是这样,你可以将训练和测试图保存到磁盘,就像你在预处理时做的那样。
定义一个GraphSAGE模型
本教程构建了一个由两个
GraphSAGE 层组成的模型,每个层通过平均邻居信息来计算新的节点表示。DGL 提供了
dgl.nn.SAGEConv
,方便地创建一个 GraphSAGE 层。
from dgl.nn import SAGEConv
# ----------- 2. create model -------------- #
# build a two-layer GraphSAGE model
class GraphSAGE(nn.Module):
def __init__(self, in_feats, h_feats):
super(GraphSAGE, self).__init__()
self.conv1 = SAGEConv(in_feats, h_feats, "mean")
self.conv2 = SAGEConv(h_feats, h_feats, "mean")
def forward(self, g, in_feat):
h = self.conv1(g, in_feat)
h = F.relu(h)
h = self.conv2(g, h)
return h
然后,模型通过计算两个相关节点表示之间的分数(例如使用MLP或点积函数)来预测边存在的概率,您将在下一节中看到这一点。
正图、负图和apply_edges
在之前的教程中,您已经学习了如何使用GNN计算节点表示。然而,链接预测需要您计算节点对的表示。
DGL建议您将节点对视为另一个图,因为您可以用边来描述一对节点。在链接预测中,您将有一个正图,由所有正例作为边组成,以及一个负图,由所有负例组成。正图和负图将包含与原始图相同的节点集。这使得在多个图之间传递节点特征以进行计算变得更加容易。正如您稍后将看到的,您可以直接将整个图上计算的节点表示输入到正图和负图中,以计算成对分数。
以下代码分别为训练集和测试集构建正图和负图。
train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.num_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.num_nodes())
test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.num_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.num_nodes())
将节点对视为图的好处是,你可以使用DGLGraph.apply_edges
方法,该方法方便地基于入射节点的特征和原始边特征(如果适用)计算新的边特征。
DGL 提供了一组优化的内置函数,用于基于原始节点/边特征计算新的边特征。例如,
dgl.function.u_dot_v
计算每条边的入射节点表示的点积。
import dgl.function as fn
class DotPredictor(nn.Module):
def forward(self, g, h):
with g.local_scope():
g.ndata["h"] = h
# Compute a new edge feature named 'score' by a dot-product between the
# source node feature 'h' and destination node feature 'h'.
g.apply_edges(fn.u_dot_v("h", "h", "score"))
# u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
return g.edata["score"][:, 0]
如果情况复杂,你也可以编写自己的函数。 例如,以下模块通过连接相邻节点的特征并将其传递给MLP,为每条边生成一个标量分数。
class MLPPredictor(nn.Module):
def __init__(self, h_feats):
super().__init__()
self.W1 = nn.Linear(h_feats * 2, h_feats)
self.W2 = nn.Linear(h_feats, 1)
def apply_edges(self, edges):
"""
Computes a scalar score for each edge of the given graph.
Parameters
----------
edges :
Has three members ``src``, ``dst`` and ``data``, each of
which is a dictionary representing the features of the
source nodes, the destination nodes, and the edges
themselves.
Returns
-------
dict
A dictionary of new edge features.
"""
h = torch.cat([edges.src["h"], edges.dst["h"]], 1)
return {"score": self.W2(F.relu(self.W1(h))).squeeze(1)}
def forward(self, g, h):
with g.local_scope():
g.ndata["h"] = h
g.apply_edges(self.apply_edges)
return g.edata["score"]
注意
内置函数在速度和内存方面都进行了优化。 我们建议尽可能使用内置函数。
注意
如果你已经阅读了消息传递教程,你会注意到参数apply_edges
的形式与update_all
中的消息函数完全相同。
Training loop
在定义了节点表示计算和边分数计算之后,您可以继续定义整体模型、损失函数和评估指标。
损失函数仅仅是二元交叉熵损失。
本教程中的评估指标是AUC。
model = GraphSAGE(train_g.ndata["feat"].shape[1], 16)
# You can replace DotPredictor with MLPPredictor.
# pred = MLPPredictor(16)
pred = DotPredictor()
def compute_loss(pos_score, neg_score):
scores = torch.cat([pos_score, neg_score])
labels = torch.cat(
[torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
)
return F.binary_cross_entropy_with_logits(scores, labels)
def compute_auc(pos_score, neg_score):
scores = torch.cat([pos_score, neg_score]).numpy()
labels = torch.cat(
[torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
).numpy()
return roc_auc_score(labels, scores)
训练循环如下:
注意
本教程不包括在验证集上的评估。在实际操作中,您应该根据验证集上的表现保存并评估最佳模型。
# ----------- 3. set up loss and optimizer -------------- #
# in this case, loss will in training loop
optimizer = torch.optim.Adam(
itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)
# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(100):
# forward
h = model(train_g, train_g.ndata["feat"])
pos_score = pred(train_pos_g, h)
neg_score = pred(train_neg_g, h)
loss = compute_loss(pos_score, neg_score)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
if e % 5 == 0:
print("In epoch {}, loss: {}".format(e, loss))
# ----------- 5. check results ------------------------ #
from sklearn.metrics import roc_auc_score
with torch.no_grad():
pos_score = pred(test_pos_g, h)
neg_score = pred(test_neg_g, h)
print("AUC", compute_auc(pos_score, neg_score))
# Thumbnail credits: Link Prediction with Neo4j, Mark Needham
# sphinx_gallery_thumbnail_path = '_static/blitz_4_link_predict.png'
In epoch 0, loss: 0.7021162509918213
In epoch 5, loss: 0.689262330532074
In epoch 10, loss: 0.6682888865470886
In epoch 15, loss: 0.6222915053367615
In epoch 20, loss: 0.5650363564491272
In epoch 25, loss: 0.5266429781913757
In epoch 30, loss: 0.4944186508655548
In epoch 35, loss: 0.4720936417579651
In epoch 40, loss: 0.4443744719028473
In epoch 45, loss: 0.4237181842327118
In epoch 50, loss: 0.40078991651535034
In epoch 55, loss: 0.3775203824043274
In epoch 60, loss: 0.3544590473175049
In epoch 65, loss: 0.33114877343177795
In epoch 70, loss: 0.30858904123306274
In epoch 75, loss: 0.285332590341568
In epoch 80, loss: 0.2629302740097046
In epoch 85, loss: 0.2413136214017868
In epoch 90, loss: 0.2202303111553192
In epoch 95, loss: 0.19989247620105743
AUC 0.8439091664607714
脚本的总运行时间: (0 分钟 4.066 秒)