Relational Graph Convolutional Network
Author: Lingfan Yu, Mufei Li, Zheng Zhang
Warning
The tutorial aims at gaining insights into the paper, with code as a means of explanation. The implementation is therefore NOT optimized for running efficiency. For a recommended implementation, please refer to the official examples.
In this tutorial, you learn how to implement a relational graph convolutional network (R-GCN). This type of network is one effort to generalize GCN to handle different relationships between entities in a knowledge base. To learn more about the research behind R-GCN, see Modeling Relational Data with Graph Convolutional Networks.
A straightforward graph convolutional network (GCN) exploits structural information of a dataset (that is, the graph connectivity) to improve the extraction of node representations. Graph edges are left as untyped.
A knowledge graph is made up of a collection of triples in the form subject, relation, object. Edges thus encode important information and have their own embeddings to be learned. Furthermore, there may exist multiple edges among any given pair of nodes.
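As a concrete (hypothetical) illustration, a handful of such triples can be encoded as a DGL heterogeneous graph, with one edge type per relation:

import dgl
import torch

# Three toy triples: (0, lived_in, 1), (0, educated_at, 2), (2, located_in, 1).
# Each relation becomes its own edge type in the heterograph.
kg = dgl.heterograph({
    ("entity", "lived_in", "entity"): (torch.tensor([0]), torch.tensor([1])),
    ("entity", "educated_at", "entity"): (torch.tensor([0]), torch.tensor([2])),
    ("entity", "located_in", "entity"): (torch.tensor([2]), torch.tensor([1])),
})
print(kg.canonical_etypes)  # three relation types over one node type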
A brief introduction to R-GCN
In statistical relational learning (SRL), there are two fundamental tasks:
Entity classification - where you assign types and categorical properties to entities.
Link prediction - where you recover missing triples.
In both cases, missing information is expected to be recovered from the neighborhood structure of the graph. For example, the R-GCN paper cited earlier provides the following example. Knowing that Mikhail Baryshnikov was educated at the Vaganova Academy implies both that Mikhail Baryshnikov should have the label person, and that the triple (Mikhail Baryshnikov, lived in, Russia) must belong to the knowledge graph.
R-GCN solves these two problems using a common graph convolutional network. It is extended with multi-edge encoding to compute embeddings of the entities, but with different downstream processing.
Entity classification is done by attaching a softmax classifier at the final embedding of an entity (node). Training is through loss of standard cross-entropy.
Link prediction is done by reconstructing an edge with an autoencoder architecture, using a parameterized score function. Training uses negative sampling.
This tutorial focuses on the first task, entity classification, to show how to generate entity representations. The complete code for both tasks can be found in DGL's Github repository.
The key ideas of R-GCN
Recall that in GCN, the hidden representation for each node \(i\) at the \((l+1)^{th}\) layer is computed by:

\[h_i^{(l+1)} = \sigma\left(\sum_{j\in N_i}\frac{1}{c_i} W^{(l)} h_j^{(l)}\right) \qquad (1)\]

where \(c_i\) is a normalization constant.
The key difference between R-GCN and GCN is that in R-GCN, edges can represent different relations. In GCN, the weight \(W^{(l)}\) in equation \((1)\) is shared by all edges in layer \(l\). In contrast, in R-GCN, different edge types use different weights, and only edges of the same relation type \(r\) are associated with the same projection weight \(W_r^{(l)}\).
So the hidden representation of entities in the \((l+1)^{th}\) layer in R-GCN can be formulated as the following equation:

\[h_i^{(l+1)} = \sigma\left(\sum_{r\in R}\sum_{j\in N_i^r}\frac{1}{c_{i,r}} W_r^{(l)} h_j^{(l)} + W_0^{(l)} h_i^{(l)}\right) \qquad (2)\]

where \(N_i^r\) denotes the set of neighbor indices of node \(i\) under relation \(r\in R\) and \(c_{i,r}\) is a normalization constant. In entity classification, the R-GCN paper uses \(c_{i,r}=|N_i^r|\).
The problem of applying the above equation directly is the rapid growth of the number of parameters, especially with highly multi-relational data. In order to reduce model parameter size and prevent overfitting, the original paper proposes to use basis decomposition:

\[W_r^{(l)} = \sum_{b=1}^{B} a_{rb}^{(l)} V_b^{(l)} \qquad (3)\]

Therefore, the weight \(W_r^{(l)}\) is a linear combination of basis transformations \(V_b^{(l)}\) with coefficients \(a_{rb}^{(l)}\). The number of bases \(B\) is much smaller than the number of relations in the knowledge base.
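As a quick sketch of what equation \((3)\) computes (with hypothetical sizes, not the layer defined below), the per-relation weights can be assembled from the bases with a single einsum:

import torch

num_rels, num_bases, in_feat, out_feat = 8, 2, 16, 16  # hypothetical sizes
V = torch.randn(num_bases, in_feat, out_feat)  # basis transformations V_b
a = torch.randn(num_rels, num_bases)           # coefficients a_rb
# W[r] = sum_b a[r, b] * V[b]  -> shape (num_rels, in_feat, out_feat)
W = torch.einsum("rb,bio->rio", a, V)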
Note
Another weight regularization, block-decomposition, is implemented in the link prediction implementation.
Implement R-GCN in DGL
An R-GCN model is composed of several R-GCN layers. The first R-GCN layer also serves as the input layer and takes in features (for example, description texts) that are associated with the node entities and projects them to a hidden space. In this tutorial, only the entity ID is used as an entity feature.
R-GCN layers
For each node, an R-GCN layer performs the following steps:
Compute outgoing message using node representation and weight matrix associated with the edge type (message function)
Aggregate incoming messages and generate new node representations (reduce and apply function)
The following code is the definition of an R-GCN hidden layer.
Note
Each relation type is associated with a different weight. Therefore, the full weight matrix has three dimensions: relation, input_feature, output_feature.
Note
This is showing how to implement an R-GCN from scratch. DGL provides a more efficient builtin R-GCN layer module.
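For reference, here is a minimal hedged sketch of that builtin alternative (the RelGraphConv module; the exact signature may vary across DGL releases, so check the API docs for your version):

import dgl.nn as dglnn

# Hypothetical sizes; the 'basis' regularizer mirrors equation (3).
conv = dglnn.RelGraphConv(
    in_feat=16, out_feat=16, num_rels=4, regularizer="basis", num_bases=2
)
# Usage (assuming a homogeneous graph g, node features feat of shape
# (num_nodes, in_feat), and per-edge relation ids etypes):
# h = conv(g, feat, etypes)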
import os

os.environ["DGLBACKEND"] = "pytorch"
from functools import partial

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
class RGCNLayer(nn.Module):
def __init__(
self,
in_feat,
out_feat,
num_rels,
num_bases=-1,
bias=None,
activation=None,
is_input_layer=False,
):
super(RGCNLayer, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.num_rels = num_rels
self.num_bases = num_bases
self.bias = bias
self.activation = activation
self.is_input_layer = is_input_layer
# sanity check
if self.num_bases <= 0 or self.num_bases > self.num_rels:
self.num_bases = self.num_rels
# weight bases in equation (3)
self.weight = nn.Parameter(
torch.Tensor(self.num_bases, self.in_feat, self.out_feat)
)
if self.num_bases < self.num_rels:
# linear combination coefficients in equation (3)
self.w_comp = nn.Parameter(
torch.Tensor(self.num_rels, self.num_bases)
)
        # add bias
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feat))
        else:
            self.bias = None
# init trainable parameters
nn.init.xavier_uniform_(
self.weight, gain=nn.init.calculate_gain("relu")
)
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(
self.w_comp, gain=nn.init.calculate_gain("relu")
)
        if self.bias is not None:
            # the bias is a 1-D vector; xavier init needs >= 2 dims, so use zeros
            nn.init.zeros_(self.bias)
def forward(self, g):
if self.num_bases < self.num_rels:
            # generate all weights from bases (equation (3)):
            # W_r = sum_b a_rb * V_b, computed as one matrix product
            weight = self.weight.view(
                self.num_bases, self.in_feat * self.out_feat
            )
            weight = torch.matmul(self.w_comp, weight).view(
                self.num_rels, self.in_feat, self.out_feat
            )
else:
weight = self.weight
if self.is_input_layer:
def message_func(edges):
# for input layer, matrix multiply can be converted to be
# an embedding lookup using source node id
embed = weight.view(-1, self.out_feat)
index = edges.data[dgl.ETYPE] * self.in_feat + edges.src["id"]
return {"msg": embed[index] * edges.data["norm"]}
else:
def message_func(edges):
w = weight[edges.data[dgl.ETYPE]]
                msg = torch.bmm(edges.src["h"].unsqueeze(1), w).squeeze(1)
msg = msg * edges.data["norm"]
return {"msg": msg}
def apply_func(nodes):
h = nodes.data["h"]
            if self.bias is not None:
h = h + self.bias
if self.activation:
h = self.activation(h)
return {"h": h}
g.update_all(message_func, fn.sum(msg="msg", out="h"), apply_func)
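As a quick smoke test (a hypothetical toy graph, not part of the original tutorial), the layer can be exercised on a tiny homogeneous graph carrying the edge-type and normalization data it expects:

# Toy graph: 3 nodes, 3 edges, 2 relation types (hypothetical example).
toy_g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
toy_g.edata[dgl.ETYPE] = torch.tensor([0, 1, 0])  # relation id per edge
toy_g.edata["norm"] = torch.ones(3, 1)            # dummy normalization
toy_g.ndata["id"] = torch.arange(3)               # entity IDs as features
layer = RGCNLayer(3, 4, num_rels=2, activation=F.relu, is_input_layer=True)
layer(toy_g)
print(toy_g.ndata["h"].shape)  # expected: torch.Size([3, 4])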
Full R-GCN model defined
class Model(nn.Module):
def __init__(
self,
num_nodes,
h_dim,
out_dim,
num_rels,
num_bases=-1,
num_hidden_layers=1,
):
super(Model, self).__init__()
self.num_nodes = num_nodes
self.h_dim = h_dim
self.out_dim = out_dim
self.num_rels = num_rels
self.num_bases = num_bases
self.num_hidden_layers = num_hidden_layers
# create rgcn layers
self.build_model()
# create initial features
self.features = self.create_features()
def build_model(self):
self.layers = nn.ModuleList()
# input to hidden
i2h = self.build_input_layer()
self.layers.append(i2h)
# hidden to hidden
for _ in range(self.num_hidden_layers):
h2h = self.build_hidden_layer()
self.layers.append(h2h)
# hidden to output
h2o = self.build_output_layer()
self.layers.append(h2o)
# initialize feature for each node
def create_features(self):
features = torch.arange(self.num_nodes)
return features
def build_input_layer(self):
return RGCNLayer(
self.num_nodes,
self.h_dim,
self.num_rels,
self.num_bases,
activation=F.relu,
is_input_layer=True,
)
def build_hidden_layer(self):
return RGCNLayer(
self.h_dim,
self.h_dim,
self.num_rels,
self.num_bases,
activation=F.relu,
)
def build_output_layer(self):
return RGCNLayer(
self.h_dim,
self.out_dim,
self.num_rels,
self.num_bases,
activation=partial(F.softmax, dim=1),
)
def forward(self, g):
if self.features is not None:
g.ndata["id"] = self.features
for layer in self.layers:
layer(g)
return g.ndata.pop("h")
Handle dataset
This tutorial uses the Institute for Applied Informatics and Formal Description Methods (AIFB) dataset from the R-GCN paper.
# load graph data
dataset = dgl.data.rdf.AIFBDataset()
g = dataset[0]
category = dataset.predict_category
train_mask = g.nodes[category].data.pop("train_mask")
test_mask = g.nodes[category].data.pop("test_mask")
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop("label")
num_rels = len(g.canonical_etypes)
num_classes = dataset.num_classes
# normalization factor
for cetype in g.canonical_etypes:
g.edges[cetype].data["norm"] = dgl.norm_by_dst(g, cetype).unsqueeze(1)
category_id = g.ntypes.index(category)
Done loading data from cached files.
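For clarity (a hedged sketch, not part of the tutorial script), dgl.norm_by_dst assigns each edge the value 1 / in-degree of its destination node, which realizes the \(c_{i,r}=|N_i^r|\) normalization from equation \((2)\). A rough manual equivalent for one canonical edge type:

# Manual equivalent of dgl.norm_by_dst (assumes g and cetype from above).
_, dst = g.edges(etype=cetype)
deg = g.in_degrees(etype=cetype).float().clamp(min=1)  # avoid divide-by-zero
norm_manual = (1.0 / deg)[dst].unsqueeze(1)            # one value per edge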
Create graph and model
# configurations
n_hidden = 16 # number of hidden units
n_bases = -1 # use number of relations as number of bases
n_hidden_layers = 0 # use 1 input layer, 1 output layer, no hidden layer
n_epochs = 25 # epochs to train
lr = 0.01 # learning rate
l2norm = 0 # L2 norm coefficient
# create graph
g = dgl.to_homogeneous(g, edata=["norm"])
node_ids = torch.arange(g.num_nodes())
target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
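# (editor note, hedged) dgl.to_homogeneous stores the original node/edge type
# ids in g.ndata[dgl.NTYPE] and g.edata[dgl.ETYPE]; the message functions in
# RGCNLayer read the latter, and target_idx picks out the nodes of the
# category to classify under the new homogeneous numbering.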
# create model
model = Model(
g.num_nodes(),
n_hidden,
num_classes,
num_rels,
num_bases=n_bases,
num_hidden_layers=n_hidden_layers,
)
Training loop
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2norm)
print("start training...")
model.train()
for epoch in range(n_epochs):
optimizer.zero_grad()
    logits = model(g)
logits = logits[target_idx]
loss = F.cross_entropy(logits[train_idx], labels[train_idx])
loss.backward()
optimizer.step()
train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx])
train_acc = train_acc.item() / len(train_idx)
val_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
val_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx])
val_acc = val_acc.item() / len(test_idx)
print(
"Epoch {:05d} | ".format(epoch)
+ "Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format(
train_acc, loss.item()
)
+ "Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format(
val_acc, val_loss.item()
)
)
start training...
Epoch 00000 | Train Accuracy: 0.2857 | Train Loss: 1.3858 | Validation Accuracy: 0.3611 | Validation loss: 1.3856
Epoch 00001 | Train Accuracy: 0.9214 | Train Loss: 1.3555 | Validation Accuracy: 0.9444 | Validation loss: 1.3616
Epoch 00002 | Train Accuracy: 0.9357 | Train Loss: 1.3086 | Validation Accuracy: 0.9444 | Validation loss: 1.3230
Epoch 00003 | Train Accuracy: 0.9357 | Train Loss: 1.2449 | Validation Accuracy: 0.9167 | Validation loss: 1.2693
Epoch 00004 | Train Accuracy: 0.9357 | Train Loss: 1.1717 | Validation Accuracy: 0.9167 | Validation loss: 1.2051
Epoch 00005 | Train Accuracy: 0.9357 | Train Loss: 1.1008 | Validation Accuracy: 0.9167 | Validation loss: 1.1406
Epoch 00006 | Train Accuracy: 0.9357 | Train Loss: 1.0401 | Validation Accuracy: 0.9167 | Validation loss: 1.0842
Epoch 00007 | Train Accuracy: 0.9357 | Train Loss: 0.9914 | Validation Accuracy: 0.9167 | Validation loss: 1.0380
Epoch 00008 | Train Accuracy: 0.9357 | Train Loss: 0.9528 | Validation Accuracy: 0.9444 | Validation loss: 1.0004
Epoch 00009 | Train Accuracy: 0.9357 | Train Loss: 0.9222 | Validation Accuracy: 0.9444 | Validation loss: 0.9696
Epoch 00010 | Train Accuracy: 0.9429 | Train Loss: 0.8972 | Validation Accuracy: 0.9444 | Validation loss: 0.9439
Epoch 00011 | Train Accuracy: 0.9500 | Train Loss: 0.8761 | Validation Accuracy: 0.9444 | Validation loss: 0.9222
Epoch 00012 | Train Accuracy: 0.9500 | Train Loss: 0.8582 | Validation Accuracy: 0.9722 | Validation loss: 0.9035
Epoch 00013 | Train Accuracy: 0.9500 | Train Loss: 0.8435 | Validation Accuracy: 0.9722 | Validation loss: 0.8876
Epoch 00014 | Train Accuracy: 0.9500 | Train Loss: 0.8320 | Validation Accuracy: 0.9722 | Validation loss: 0.8745
Epoch 00015 | Train Accuracy: 0.9500 | Train Loss: 0.8234 | Validation Accuracy: 0.9722 | Validation loss: 0.8642
Epoch 00016 | Train Accuracy: 0.9500 | Train Loss: 0.8172 | Validation Accuracy: 0.9722 | Validation loss: 0.8562
Epoch 00017 | Train Accuracy: 0.9500 | Train Loss: 0.8127 | Validation Accuracy: 0.9722 | Validation loss: 0.8501
Epoch 00018 | Train Accuracy: 0.9500 | Train Loss: 0.8094 | Validation Accuracy: 0.9722 | Validation loss: 0.8454
Epoch 00019 | Train Accuracy: 0.9500 | Train Loss: 0.8068 | Validation Accuracy: 0.9722 | Validation loss: 0.8417
Epoch 00020 | Train Accuracy: 0.9500 | Train Loss: 0.8046 | Validation Accuracy: 0.9722 | Validation loss: 0.8388
Epoch 00021 | Train Accuracy: 0.9500 | Train Loss: 0.8025 | Validation Accuracy: 0.9722 | Validation loss: 0.8363
Epoch 00022 | Train Accuracy: 0.9500 | Train Loss: 0.8005 | Validation Accuracy: 0.9722 | Validation loss: 0.8343
Epoch 00023 | Train Accuracy: 0.9500 | Train Loss: 0.7983 | Validation Accuracy: 0.9722 | Validation loss: 0.8326
Epoch 00024 | Train Accuracy: 0.9500 | Train Loss: 0.7959 | Validation Accuracy: 0.9722 | Validation loss: 0.8312
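After training, one might switch the model to evaluation mode for a final check (a sketch, not part of the original script; it reuses the test split, which here doubles as the validation set):

model.eval()
with torch.no_grad():
    logits = model(g)[target_idx]
    acc = (logits[test_idx].argmax(dim=1) == labels[test_idx]).float().mean()
print("Test Accuracy: {:.4f}".format(acc.item()))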
The second task, link prediction
So far, you have seen how to use DGL to implement entity classification with an R-GCN model. In the knowledge base setting, the representation generated by R-GCN can be used to uncover potential relationships between nodes. In the R-GCN paper, the authors feed the entity representations generated by R-GCN into the DistMult prediction model to predict possible relationships.
The implementation is similar to the one presented here, but with an extra DistMult layer stacked on top of the R-GCN layers. You can find the complete implementation of link prediction with R-GCN in our Github Python code example.
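To make the scoring step concrete, here is a minimal hedged sketch of DistMult (hypothetical tensors and names; the full model also needs negative sampling and a training loop). The score of a triple \((s, r, o)\) is \(e_s^\top \mathrm{diag}(w_r) e_o\):

# Hypothetical sizes; entity_emb would come from the trained R-GCN.
num_entities, num_relations, dim = 100, 4, 16
entity_emb = torch.randn(num_entities, dim)
rel_emb = nn.Parameter(torch.randn(num_relations, dim))  # one vector per relation

def distmult_score(s, r, o):
    # score(s, r, o) = sum_d e_s[d] * w_r[d] * e_o[d]
    return (entity_emb[s] * rel_emb[r] * entity_emb[o]).sum(dim=-1)

triples = torch.tensor([[0, 1, 2], [3, 0, 4]])  # (subject, relation, object)
scores = distmult_score(triples[:, 0], triples[:, 1], triples[:, 2])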
Total running time of the script: (0 minutes 3.264 seconds)