5.2 边分类/回归

(中文版)

有时你希望预测图边缘的属性。在这种情况下,你可能需要一个边缘分类/回归模型。

这里我们生成一个随机图用于边预测作为演示。

src = np.random.randint(0, 100, 500)
dst = np.random.randint(0, 100, 500)
# make it symmetric
edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))
# synthetic node and edge features, as well as edge labels
edge_pred_graph.ndata['feature'] = torch.randn(100, 10)
edge_pred_graph.edata['feature'] = torch.randn(1000, 10)
edge_pred_graph.edata['label'] = torch.randn(1000)
# synthetic train-validation-test splits
edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)

Overview

从前面的部分你已经学会了如何使用多层GNN进行节点分类。同样的技术可以应用于计算任何节点的隐藏表示。然后可以从它们的关联节点的表示中推导出对边的预测。

在边上计算预测的最常见情况是将其表示为其相邻节点表示的参数化函数,并且可以选择性地表示边本身的特征。

模型实现与节点分类的差异

假设你使用上一节中的模型计算节点表示,你只需要编写另一个组件,使用apply_edges()方法来计算边预测。

例如,如果你想为边回归计算每条边的分数,以下代码计算每条边上入射节点表示的点积。

import dgl.function as fn
class DotProductPredictor(nn.Module):
    def forward(self, graph, h):
        # h contains the node representations computed from the GNN defined
        # in the node classification section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']

也可以编写一个预测函数,使用MLP为每条边预测一个向量。这样的向量可以用于进一步的下游任务,例如作为分类分布的logits。

class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)

    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        return {'score': score}

    def forward(self, graph, h):
        # h contains the node representations computed from the GNN defined
        # in the node classification section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

Training loop

给定节点表示计算模型和边预测模型,我们可以轻松编写一个全图训练循环,在其中计算所有边的预测。

以下示例将上一节中的SAGE作为节点表示计算模型,并将DotPredictor作为边预测模型。

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.sage = SAGE(in_features, hidden_features, out_features)
        self.pred = DotProductPredictor()
    def forward(self, g, x):
        h = self.sage(g, x)
        return self.pred(g, h)

在这个例子中,我们还假设训练/验证/测试边集是通过边的布尔掩码来识别的。这个例子也不包括早期停止和模型保存。

node_features = edge_pred_graph.ndata['feature']
edge_label = edge_pred_graph.edata['label']
train_mask = edge_pred_graph.edata['train_mask']
model = Model(10, 20, 5)
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    pred = model(edge_pred_graph, node_features)
    loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

异构图

在异质图上进行边分类与在同质图上进行边分类并没有太大区别。如果你想在一种边类型上进行边分类,你只需要计算所有节点类型的节点表示,并使用apply_edges()方法对该边类型进行预测。

例如,要使DotProductPredictor在异构图的一种边类型上工作,您只需要在apply_edges方法中指定边类型。

class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        # h contains the node representations for each edge type computed from
        # the GNN for heterogeneous graphs defined in the node classification
        # section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h   # assigns 'h' of all node types in one shot
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']

你可以类似地编写一个HeteroMLPPredictor

class HeteroMLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)

    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        return {'score': score}

    def forward(self, graph, h, etype):
        # h contains the node representations for each edge type computed from
        # the GNN for heterogeneous graphs defined in the node classification
        # section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h   # assigns 'h' of all node types in one shot
            graph.apply_edges(self.apply_edges, etype=etype)
            return graph.edges[etype].data['score']

预测单一边缘类型上每条边缘得分的端到端模型将如下所示:

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()
    def forward(self, g, x, etype):
        h = self.sage(g, x)
        return self.pred(g, h, etype)

使用模型只需将节点类型和特征的字典输入模型。

model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
label = hetero_graph.edges['click'].data['label']
train_mask = hetero_graph.edges['click'].data['train_mask']
node_features = {'user': user_feats, 'item': item_feats}

然后训练循环看起来几乎与同质图中的相同。例如,如果你想预测边类型 click 上的边标签,那么你可以简单地这样做

opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    pred = model(hetero_graph, node_features, 'click')
    loss = ((pred[train_mask] - label[train_mask]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

预测异构图上现有边的类型

有时你可能想预测现有边属于哪种类型。

例如,给定 异构图示例, 你的任务是给定一条连接用户和项目的边,预测用户是否会 点击不喜欢 一个项目。

这是一个简化版的评分预测,这在推荐文献中很常见。

你可以使用异构图卷积网络来获取节点表示。例如,你仍然可以使用 之前定义的RGCN 来实现这一目的。

要预测边的类型,你可以简单地重新利用上面的HeteroDotProductPredictor,使其接受另一个图,该图仅包含一种边类型,该类型“合并”了所有要预测的边类型,并为每条边输出每种类型的分数。

在这个例子中,你需要一个包含两种节点类型的图 useritem,以及一个单一的边类型,它将所有来自 useritem 的边类型“合并”,即 clickdislike。 这可以使用以下语法方便地创建:

dec_graph = hetero_graph['user', :, 'item']

它返回一个包含节点类型useritem的异构图,以及一个结合了所有中间边类型的单一边类型,即clickdislike

由于上述语句还返回了原始边类型作为一个名为 dgl.ETYPE 的特征,我们可以将其用作标签。

edge_label = dec_graph.edata[dgl.ETYPE]

将上述图表作为边类型预测器模块的输入,您可以按如下方式编写预测器模块。

class HeteroMLPPredictor(nn.Module):
    def __init__(self, in_dims, n_classes):
        super().__init__()
        self.W = nn.Linear(in_dims * 2, n_classes)

    def apply_edges(self, edges):
        x = torch.cat([edges.src['h'], edges.dst['h']], 1)
        y = self.W(x)
        return {'score': y}

    def forward(self, graph, h):
        # h contains the node representations for each edge type computed from
        # the GNN for heterogeneous graphs defined in the node classification
        # section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h   # assigns 'h' of all node types in one shot
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

结合节点表示模块和边类型预测器模块的模型如下:

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroMLPPredictor(out_features, len(rel_names))
    def forward(self, g, x, dec_graph):
        h = self.sage(g, x)
        return self.pred(dec_graph, h)

训练循环可以简单地如下所示:

model = Model(10, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}

opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    logits = model(hetero_graph, node_features, dec_graph)
    loss = F.cross_entropy(logits, edge_label)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

DGL 提供了 图卷积矩阵补全 作为评分预测的示例,该示例通过预测异质图上现有边的类型来建模。模型实现文件 中的节点表示模块称为 GCMCLayer。边类型预测模块称为 BiDecoder。它们都比这里描述的场景更复杂。