DGL
latest

开始使用

  • 安装和设置
  • DGL 快速入门

先进材料

  • 🆕 使用GraphBolt进行GNN的随机训练
  • 用户指南
    • 第1章:图
    • 第2章:消息传递
    • 第3章:构建GNN模块
    • 第4章:图数据管道
    • 第5章:训练图神经网络
      • 5.1 节点分类/回归
      • 5.2 边分类/回归
      • 5.3 链接预测
      • 5.4 图分类
      • 5.5 使用边权重
    • 第6章:大型图上的随机训练
    • 第7章:分布式训练
    • 第8章:混合精度训练
  • 用户指南【包含过时信息】
  • 用户指南[过时的]
  • 🆕 教程: 图变换器
  • 教程: dgl.sparse
  • 在CPU上进行训练
  • 在多GPU上进行训练
  • 分布式训练
  • 使用DGL进行论文研究

API 参考

  • dgl
  • dgl.data
  • dgl.dataloading
  • dgl.DGLGraph
  • dgl.distributed
  • dgl.function
  • dgl.geometry
  • 🆕 dgl.graphbolt
  • dgl.nn (PyTorch)
  • dgl.nn.functional
  • dgl.ops
  • dgl.optim
  • dgl.sampling
  • dgl.sparse
  • dgl.multiprocessing
  • dgl.transforms
  • 用户自定义函数

注释

  • 为DGL做贡献
  • DGL 外部函数接口 (FFI)
  • 性能基准测试

杂项

  • 常见问题解答 (FAQ)
  • 环境变量
  • 资源
DGL
  • User Guide
  • Chapter 5: Training Graph Neural Networks
  • 5.3 Link Prediction
  • Edit on GitHub

5.3 链接预测

(中文版)

在其他一些设置中,您可能想要预测两个给定节点之间是否存在边。这样的任务被称为链接预测任务。

Overview

一个基于GNN的链接预测模型表示两个节点\(u\)和\(v\)之间连接的可能性,作为\(\boldsymbol{h}_u^{(L)}\)和\(\boldsymbol{h}_v^{(L)}\)的函数,这些节点表示是从多层GNN计算得出的。

\[y_{u,v} = \phi(\boldsymbol{h}_u^{(L)}, \boldsymbol{h}_v^{(L)})\]

在本节中,我们将\(y_{u,v}\)称为节点\(u\)和节点\(v\)之间的分数。

训练链接预测模型涉及比较通过边连接的节点之间的分数与任意一对节点之间的分数。例如,给定连接\(u\)和\(v\)的边,我们鼓励节点\(u\)和\(v\)之间的分数高于节点\(u\)与从任意噪声分布\(v' \sim P_n(v)\)中采样的节点\(v'\)之间的分数。这种方法被称为负采样。

有许多损失函数如果被最小化,可以实现上述行为。一个非详尽的列表包括:

  • 交叉熵损失: \(\mathcal{L} = - \log \sigma (y_{u,v}) - \sum_{v_i \sim P_n(v), i=1,\dots,k}\log \left[ 1 - \sigma (y_{u,v_i})\right]\)

  • BPR损失: \(\mathcal{L} = \sum_{v_i \sim P_n(v), i=1,\dots,k} - \log \sigma (y_{u,v} - y_{u,v_i})\)

  • 边际损失: \(\mathcal{L} = \sum_{v_i \sim P_n(v), i=1,\dots,k} \max(0, M - y_{u, v} + y_{u, v_i})\), 其中 \(M\) 是一个常数超参数。

如果你知道什么是隐式反馈或 噪声对比估计,你可能会觉得这个想法很熟悉。

用于计算\(u\)和\(v\)之间分数的神经网络模型与上述描述的边缘回归模型相同。

这里是一个使用点积来计算边上分数的例子。

class DotProductPredictor(nn.Module):
    def forward(self, graph, h):
        # h contains the node representations computed from the GNN defined
        # in the node classification section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']

Training loop

因为我们的分数预测模型在图上运行,我们需要将负样本表示为另一个图。该图将包含所有负节点对作为边。

以下展示了将负例表示为图的示例。每条边 \((u,v)\) 获得 \(k\) 个负例 \((u,v_i)\),其中 \(v_i\) 是从均匀分布中采样的。

def construct_negative_graph(graph, k):
    src, dst = graph.edges()

    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(), (len(src) * k,))
    return dgl.graph((neg_src, neg_dst), num_nodes=graph.num_nodes())

预测边缘分数的模型与边缘分类/回归的模型相同。

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.sage = SAGE(in_features, hidden_features, out_features)
        self.pred = DotProductPredictor()
    def forward(self, g, neg_g, x):
        h = self.sage(g, x)
        return self.pred(g, h), self.pred(neg_g, h)

训练循环随后反复构建负图并计算损失。

def compute_loss(pos_score, neg_score):
    # Margin loss
    n_edges = pos_score.shape[0]
    return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()

node_features = graph.ndata['feat']
n_features = node_features.shape[1]
k = 5
model = Model(n_features, 100, 100)
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    negative_graph = construct_negative_graph(graph, k)
    pos_score, neg_score = model(graph, negative_graph, node_features)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

训练后,可以通过以下方式获取节点表示

node_embeddings = model.sage(graph, node_features)

有多种使用节点嵌入的方法。示例包括训练下游分类器,或进行最近邻搜索或最大内积搜索以进行相关实体推荐。

异构图

在异质图上的链接预测与在同质图上的链接预测并没有太大不同。以下假设我们正在预测一种边类型,并且很容易将其扩展到多种边类型。

例如,您可以重用HeteroDotProductPredictor above 来计算链接预测中某一边类型的边的分数。

class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        # h contains the node representations for each node type computed from
        # the GNN defined in the previous section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']

为了执行负采样,可以为执行链接预测的边类型构建一个负图。

def construct_negative_graph(graph, k, etype):
    utype, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})

该模型与异构图上的边分类模型略有不同,因为您需要指定执行链接预测的边类型。

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()
    def forward(self, g, neg_g, x, etype):
        h = self.sage(g, x)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)

训练循环与同构图类似。

def compute_loss(pos_score, neg_score):
    # Margin loss
    n_edges = pos_score.shape[0]
    return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()

k = 5
model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    negative_graph = construct_negative_graph(hetero_graph, k, ('user', 'click', 'item'))
    pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('user', 'click', 'item'))
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
Previous Next

© Copyright 2018, DGL Team. Revision 2ee440a6.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
2.2.x
2.1.x
2.0.x
1.1.x
1.0.x
0.9.x
0.8.x
0.7.x
0.6.x
Downloads
On Read the Docs
Project Home
Builds