训练用于图分类的GNN

在本教程结束时,您将能够

  • 加载一个DGL提供的图分类数据集。

  • 了解readout函数的作用。

  • 了解如何创建和使用图的小批量。

  • 构建一个基于GNN的图分类模型。

  • 在DGL提供的数据集上训练和评估模型。

(预计时间:18分钟)

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F

使用GNN进行图分类的概述

图分类或回归需要一个模型来预测给定其节点和边特征的单个图的某些图级属性。分子属性预测是一个特定的应用。

本教程展示了如何为来自论文图神经网络有多强大的小数据集训练图分类模型。

加载数据

# Generate a synthetic dataset with 10000 graphs, ranging from 10 to 500 nodes.
dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)

数据集是一组图,每个图都有节点特征和一个单一的标签。可以在dim_nfeatsgclasses属性中查看GINDataset对象的节点特征维度和可能的图类别数量。

print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)


from dgl.dataloading import GraphDataLoader
Node feature dimensionality: 3
Number of graph categories: 2

定义数据加载器

图分类数据集通常包含两种类型的元素:一组图及其图级标签。类似于图像分类任务,当数据集足够大时,我们需要使用小批量进行训练。当你训练图像分类或语言建模模型时,你将使用DataLoader来迭代数据集。在DGL中,你可以使用GraphDataLoader

你也可以使用在 torch.utils.data.sampler 中提供的各种数据集采样器。 例如,本教程创建了一个训练 GraphDataLoader 和 测试 GraphDataLoader,使用 SubsetRandomSampler 来告诉 PyTorch 仅从数据集的子集中进行采样。

from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_train = int(num_examples * 0.8)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=5, drop_last=False
)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=5, drop_last=False
)

你可以尝试遍历创建的 GraphDataLoader 并查看它返回的内容:

it = iter(train_dataloader)
batch = next(it)
print(batch)
[Graph(num_nodes=189, num_edges=983,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), tensor([0, 1, 0, 1, 0])]

由于dataset中的每个元素都有一个图和一个标签,GraphDataLoader每次迭代将返回两个对象。第一个元素是批处理图,第二个元素是一个简单的标签向量,表示小批次中每个图的类别。接下来,我们将讨论批处理图。

DGL中的批处理图

在每个小批量中,采样的图通过dgl.batch组合成一个更大的批量图。这个更大的批量图将所有原始图合并为单独连接的组件,节点和边的特征被连接起来。这个更大的图也是一个DGLGraph实例(因此你仍然可以像这里一样将其视为一个正常的DGLGraph对象)。然而,它包含了恢复原始图所需的信息,例如每个图元素的节点和边的数量。

batched_graph, labels = batch
print(
    "Number of nodes for each graph element in the batch:",
    batched_graph.batch_num_nodes(),
)
print(
    "Number of edges for each graph element in the batch:",
    batched_graph.batch_num_edges(),
)

# Recover the original graph elements from the minibatch
graphs = dgl.unbatch(batched_graph)
print("The original graphs in the minibatch:")
print(graphs)
Number of nodes for each graph element in the batch: tensor([69, 13, 24, 20, 63])
Number of edges for each graph element in the batch: tensor([359,  63, 108,  94, 359])
The original graphs in the minibatch:
[Graph(num_nodes=69, num_edges=359,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=13, num_edges=63,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=24, num_edges=108,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=20, num_edges=94,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=63, num_edges=359,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})]

定义模型

本教程将构建一个两层的图卷积网络 (GCN)。每一层通过聚合邻居信息来计算新的节点表示。如果你已经阅读了 介绍,你会注意到两个 不同之处:

  • 由于任务是预测整个图的单一类别,而不是每个节点的类别,您需要聚合所有节点的表示,可能还包括边的表示,以形成图级别的表示。这个过程通常被称为readout。一个简单的选择是使用dgl.mean_nodes()来平均图的节点特征。

  • 模型的输入图将是由GraphDataLoader生成的批量图。DGL提供的读取函数可以处理批量图,以便它们为每个小批量元素返回一个表示。

from dgl.nn import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")

Training Loop

训练循环使用GraphDataLoader对象遍历训练集并计算梯度,就像图像分类或语言建模一样。

# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata["attr"].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata["attr"].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

print("Test accuracy:", num_correct / num_tests)
Test accuracy: 0.09865470852017937

接下来是什么

  • 请参阅GIN示例以了解端到端的图分类模型。

# Thumbnail credits: DGL
# sphinx_gallery_thumbnail_path = '_static/blitz_5_graph_classification.png'

脚本的总运行时间: (0 分钟 29.040 秒)

Gallery generated by Sphinx-Gallery