5.1 节点分类/回归

(中文版)

图神经网络中最受欢迎和广泛采用的任务之一是节点分类,其中训练/验证/测试集中的每个节点都被分配了一个来自预定义类别的真实类别。节点回归类似,其中训练/验证/测试集中的每个节点都被分配了一个真实的数值。

Overview

为了对节点进行分类,图神经网络执行在第2章:消息传递中讨论的消息传递,以利用节点自身的特征,还包括其邻近节点和边的特征。消息传递可以重复多轮,以整合来自更大范围邻域的信息。

编写神经网络模型

DGL 提供了一些内置的图卷积模块,可以执行一轮消息传递。在本指南中,我们选择 dgl.nn.pytorch.SAGEConv(也可在 MXNet 和 Tensorflow 中使用),这是用于 GraphSAGE 的图卷积模块。

通常对于图上的深度学习模型,我们需要一个多层的图神经网络,其中我们进行多轮的消息传递。这可以通过如下方式堆叠图卷积模块来实现。

# Contruct a two-layer GNN model
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h

请注意,您不仅可以将上述模型用于节点分类, 还可以获取隐藏节点表示以用于其他下游任务,例如 5.2 边分类/回归, 5.3 链接预测, 或 5.4 图分类.

有关内置图卷积模块的完整列表,请参阅apinn

有关DGL神经网络模块如何工作以及如何使用消息传递编写自定义神经网络模块的更多详细信息,请参阅第3章:构建GNN模块中的示例。

Training loop

在全图上进行训练仅涉及上述模型的前向传播,并通过将预测结果与训练节点上的真实标签进行比较来计算损失。

本节使用DGL内置数据集 dgl.data.CiteseerGraphDataset 来 展示一个训练循环。节点特征 和标签存储在其图实例中,并且 训练-验证-测试分割也作为布尔 掩码存储在图中。这与你在第4章:图数据管道中看到的类似。

node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)

以下是通过准确率评估模型的示例。

def evaluate(model, graph, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

然后你可以按照以下方式编写我们的训练循环。

model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())

for epoch in range(10):
    model.train()
    # forward propagation by using all nodes
    logits = model(graph, node_features)
    # compute loss
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # compute validation accuracy
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

    # Save model if necessary.  Omitted in this example.

GraphSAGE 提供了一个端到端的同构图节点分类示例。 您可以看到相应的模型实现位于示例中的 GraphSAGE 类中,具有可调整的层数、 dropout概率,以及可定制的聚合函数和 非线性函数。

Heterogeneous graph

如果你的图是异质的,你可能希望从所有边类型的邻居那里收集消息。你可以使用模块 dgl.nn.pytorch.HeteroGraphConv(在MXNet和Tensorflow中也可用) 来在所有边类型上执行消息传递,然后为每种边类型组合不同的图卷积模块。

以下代码将定义一个异构图卷积模块,该模块首先对每种边类型执行单独的图卷积,然后将每种边类型的消息聚合求和作为所有节点类型的最终结果。

# Define a Heterograph Conv model

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

dgl.nn.HeteroGraphConv 接受一个节点类型和节点特征张量的字典作为输入,并返回另一个节点类型和节点特征的字典。

因此,假设我们在异构图示例中拥有用户和物品特征。

model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
labels = hetero_graph.nodes['user'].data['label']
train_mask = hetero_graph.nodes['user'].data['train_mask']

可以简单地执行前向传播如下:

node_features = {'user': user_feats, 'item': item_feats}
h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})
h_user = h_dict['user']
h_item = h_dict['item']

训练循环与同构图相同,只是现在你有一个节点表示的字典,从中计算预测。例如,如果你只预测user节点,你可以从返回的字典中提取user节点嵌入:

opt = torch.optim.Adam(model.parameters())

for epoch in range(5):
    model.train()
    # forward propagation by using all nodes and extracting the user embeddings
    logits = model(hetero_graph, node_features)['user']
    # compute loss
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    # Compute validation accuracy.  Omitted in this example.
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

    # Save model if necessary.  Omitted in the example.

DGL 提供了一个端到端的示例,用于节点分类的 RGCN。 你可以在模型实现文件中的 RelGraphConvLayer 中看到异构图卷积的定义。