5.3 链接预测
在其他一些设置中,您可能想要预测两个给定节点之间是否存在边。这样的任务被称为链接预测任务。
Overview
一个基于GNN的链接预测模型表示两个节点\(u\)和\(v\)之间连接的可能性,作为\(\boldsymbol{h}_u^{(L)}\)和\(\boldsymbol{h}_v^{(L)}\)的函数,这些节点表示是从多层GNN计算得出的。
在本节中,我们将\(y_{u,v}\)称为节点\(u\)和节点\(v\)之间的分数。
训练链接预测模型涉及比较通过边连接的节点之间的分数与任意一对节点之间的分数。例如,给定连接\(u\)和\(v\)的边,我们鼓励节点\(u\)和\(v\)之间的分数高于节点\(u\)与从任意噪声分布\(v' \sim P_n(v)\)中采样的节点\(v'\)之间的分数。这种方法被称为负采样。
有许多损失函数如果被最小化,可以实现上述行为。一个非详尽的列表包括:
交叉熵损失: \(\mathcal{L} = - \log \sigma (y_{u,v}) - \sum_{v_i \sim P_n(v), i=1,\dots,k}\log \left[ 1 - \sigma (y_{u,v_i})\right]\)
BPR损失: \(\mathcal{L} = \sum_{v_i \sim P_n(v), i=1,\dots,k} - \log \sigma (y_{u,v} - y_{u,v_i})\)
边际损失: \(\mathcal{L} = \sum_{v_i \sim P_n(v), i=1,\dots,k} \max(0, M - y_{u, v} + y_{u, v_i})\), 其中 \(M\) 是一个常数超参数。
如果你知道什么是隐式反馈或 噪声对比估计,你可能会觉得这个想法很熟悉。
用于计算\(u\)和\(v\)之间分数的神经网络模型与上述描述的边缘回归模型相同。
这里是一个使用点积来计算边上分数的例子。
class DotProductPredictor(nn.Module):
def forward(self, graph, h):
# h contains the node representations computed from the GNN defined
# in the node classification section (Section 5.1).
with graph.local_scope():
graph.ndata['h'] = h
graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
return graph.edata['score']
Training loop
因为我们的分数预测模型在图上运行,我们需要将负样本表示为另一个图。该图将包含所有负节点对作为边。
以下展示了将负例表示为图的示例。每条边 \((u,v)\) 获得 \(k\) 个负例 \((u,v_i)\),其中 \(v_i\) 是从均匀分布中采样的。
def construct_negative_graph(graph, k):
src, dst = graph.edges()
neg_src = src.repeat_interleave(k)
neg_dst = torch.randint(0, graph.num_nodes(), (len(src) * k,))
return dgl.graph((neg_src, neg_dst), num_nodes=graph.num_nodes())
预测边缘分数的模型与边缘分类/回归的模型相同。
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super().__init__()
self.sage = SAGE(in_features, hidden_features, out_features)
self.pred = DotProductPredictor()
def forward(self, g, neg_g, x):
h = self.sage(g, x)
return self.pred(g, h), self.pred(neg_g, h)
训练循环随后反复构建负图并计算损失。
def compute_loss(pos_score, neg_score):
# Margin loss
n_edges = pos_score.shape[0]
return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()
node_features = graph.ndata['feat']
n_features = node_features.shape[1]
k = 5
model = Model(n_features, 100, 100)
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
negative_graph = construct_negative_graph(graph, k)
pos_score, neg_score = model(graph, negative_graph, node_features)
loss = compute_loss(pos_score, neg_score)
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
训练后,可以通过以下方式获取节点表示
node_embeddings = model.sage(graph, node_features)
有多种使用节点嵌入的方法。示例包括训练下游分类器,或进行最近邻搜索或最大内积搜索以进行相关实体推荐。
异构图
在异质图上的链接预测与在同质图上的链接预测并没有太大不同。以下假设我们正在预测一种边类型,并且很容易将其扩展到多种边类型。
例如,您可以重用HeteroDotProductPredictor
above
来计算链接预测中某一边类型的边的分数。
class HeteroDotProductPredictor(nn.Module):
def forward(self, graph, h, etype):
# h contains the node representations for each node type computed from
# the GNN defined in the previous section (Section 5.1).
with graph.local_scope():
graph.ndata['h'] = h
graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
return graph.edges[etype].data['score']
为了执行负采样,可以为执行链接预测的边类型构建一个负图。
def construct_negative_graph(graph, k, etype):
utype, _, vtype = etype
src, dst = graph.edges(etype=etype)
neg_src = src.repeat_interleave(k)
neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
return dgl.heterograph(
{etype: (neg_src, neg_dst)},
num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})
该模型与异构图上的边分类模型略有不同,因为您需要指定执行链接预测的边类型。
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, rel_names):
super().__init__()
self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
self.pred = HeteroDotProductPredictor()
def forward(self, g, neg_g, x, etype):
h = self.sage(g, x)
return self.pred(g, h, etype), self.pred(neg_g, h, etype)
训练循环与同构图类似。
def compute_loss(pos_score, neg_score):
# Margin loss
n_edges = pos_score.shape[0]
return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()
k = 5
model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
negative_graph = construct_negative_graph(hetero_graph, k, ('user', 'click', 'item'))
pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('user', 'click', 'item'))
loss = compute_loss(pos_score, neg_score)
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())