使用DGL进行节点分类

GNNs 是用于图上的许多机器学习任务的强大工具。在本入门教程中,您将学习使用 GNNs 进行节点分类的基本工作流程,即预测图中节点的类别。

通过完成本教程,您将能够

  • 加载一个DGL提供的数据集。

  • 使用DGL提供的神经网络模块构建一个GNN模型。

  • 在CPU或GPU上训练和评估用于节点分类的GNN模型。

本教程假设您有使用PyTorch构建神经网络的经验。

(预计时间:13分钟)

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F

使用GNN进行节点分类的概述

在图数据上最受欢迎和广泛采用的任务之一是节点分类,其中模型需要预测每个节点的真实类别。在图神经网络之前,许多提出的方法仅使用连接性(如DeepWalk或node2vec),或者简单地结合连接性和节点自身的特征。相比之下,GNNs提供了一个机会,通过结合局部邻域的连接性和特征来获得节点表示。

Kipf 等人,是一个将节点分类问题表述为半监督节点分类任务的例子。仅借助一小部分标记节点的帮助,图神经网络(GNN)可以准确预测其他节点的类别。

本教程将展示如何在Cora数据集上构建一个用于半监督节点分类的GNN,该数据集是一个以论文为节点、引用为边的引用网络。任务是预测给定论文的类别。每个论文节点包含一个词频向量作为其特征,这些特征被归一化,使其总和为一,如论文的第5.2节所述。

加载Cora数据集

dataset = dgl.data.CoraGraphDataset()
print(f"Number of categories: {dataset.num_classes}")
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Number of categories: 7

一个DGL数据集对象可能包含一个或多个图。本教程中使用的Cora数据集仅包含一个单独的图。

g = dataset[0]

DGL图可以在两个类似字典的属性中存储节点特征和边特征,这两个属性称为ndataedata。在DGL Cora数据集中,图包含以下节点特征:

  • train_mask: 一个布尔张量,指示节点是否在训练集中。

  • val_mask: 一个布尔张量,指示节点是否在验证集中。

  • test_mask: 一个布尔张量,指示节点是否在测试集中。

  • label: 地面真实节点类别。

  • feat: 节点特征。

print("Node features")
print(g.ndata)
print("Edge features")
print(g.edata)
Node features
{'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), 'label': tensor([3, 4, 4,  ..., 3, 3, 3]), 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]), 'train_mask': tensor([ True,  True,  True,  ..., False, False, False]), 'val_mask': tensor([False, False, False,  ..., False, False, False])}
Edge features
{}

定义图卷积网络 (GCN)

本教程将构建一个两层的图卷积网络 (GCN)。每一层通过聚合邻居信息来计算新的节点表示。

要构建多层GCN,你可以简单地堆叠dgl.nn.GraphConv模块,这些模块继承自torch.nn.Module

from dgl.nn import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


# Create the model with given dimensions
model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)

DGL 提供了许多流行的邻居聚合模块的实现。你可以用一行代码轻松调用它们。

训练GCN

训练这个GCN类似于训练其他PyTorch神经网络。

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(100):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that you should only compute the losses of the nodes in the training set.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print(
                f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})"
            )


model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)
In epoch 0, loss: 1.946, val acc: 0.102 (best 0.102), test acc: 0.106 (best 0.106)
In epoch 5, loss: 1.902, val acc: 0.486 (best 0.514), test acc: 0.478 (best 0.533)
In epoch 10, loss: 1.831, val acc: 0.644 (best 0.644), test acc: 0.650 (best 0.650)
In epoch 15, loss: 1.732, val acc: 0.648 (best 0.654), test acc: 0.649 (best 0.653)
In epoch 20, loss: 1.610, val acc: 0.650 (best 0.654), test acc: 0.655 (best 0.653)
In epoch 25, loss: 1.466, val acc: 0.662 (best 0.662), test acc: 0.669 (best 0.669)
In epoch 30, loss: 1.304, val acc: 0.690 (best 0.690), test acc: 0.685 (best 0.685)
In epoch 35, loss: 1.133, val acc: 0.692 (best 0.698), test acc: 0.703 (best 0.696)
In epoch 40, loss: 0.961, val acc: 0.712 (best 0.712), test acc: 0.716 (best 0.716)
In epoch 45, loss: 0.799, val acc: 0.726 (best 0.726), test acc: 0.730 (best 0.730)
In epoch 50, loss: 0.654, val acc: 0.730 (best 0.730), test acc: 0.740 (best 0.740)
In epoch 55, loss: 0.530, val acc: 0.742 (best 0.742), test acc: 0.746 (best 0.743)
In epoch 60, loss: 0.429, val acc: 0.744 (best 0.744), test acc: 0.752 (best 0.751)
In epoch 65, loss: 0.347, val acc: 0.742 (best 0.744), test acc: 0.750 (best 0.751)
In epoch 70, loss: 0.282, val acc: 0.744 (best 0.744), test acc: 0.752 (best 0.751)
In epoch 75, loss: 0.231, val acc: 0.752 (best 0.752), test acc: 0.755 (best 0.755)
In epoch 80, loss: 0.191, val acc: 0.752 (best 0.754), test acc: 0.756 (best 0.756)
In epoch 85, loss: 0.159, val acc: 0.752 (best 0.754), test acc: 0.756 (best 0.756)
In epoch 90, loss: 0.134, val acc: 0.752 (best 0.754), test acc: 0.757 (best 0.756)
In epoch 95, loss: 0.113, val acc: 0.752 (best 0.754), test acc: 0.759 (best 0.756)

在GPU上进行训练

在GPU上进行训练需要将模型和图都放到GPU上,使用to方法,类似于在PyTorch中所做的操作。

g = g.to('cuda')
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
train(g, model)

接下来是什么?

# Thumbnail credits: Stanford CS224W Notes
# sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'

脚本的总运行时间: (0 分钟 4.667 秒)

Gallery generated by Sphinx-Gallery