Node Classification

This tutorial shows how to train a multi-layer GraphSAGE for node classification on the ogbn-arxiv dataset provided by the Open Graph Benchmark (OGB). The dataset contains around 170 thousand nodes and 1 million edges.


By the end of this tutorial, you will be able to

  • Train a GNN model for node classification on a single GPU with DGL's neighbor sampling components.

Install DGL package

[1]:
# Install required packages.
import os
import torch
import numpy as np
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Install the CPU version by default. If you want to install the CUDA version,
# please refer to https://www.dgl.ai/pages/start.html and change the runtime type
# accordingly.
device = torch.device("cpu")
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html

try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
Looking in links: https://data.dgl.ai/wheels-test/repo.html
Requirement already satisfied: dgl in /home/ubuntu/prod-doc/readthedocs.org/user_builds/dgl/envs/latest/lib/python3.8/site-packages/dgl-2.3-py3.8-linux-x86_64.egg (2.3)
DGL installed!

Loading Dataset

ogbn-arxiv is already prepared as a BuiltinDataset in GraphBolt.

[2]:
dataset = gb.BuiltinDataset("ogbn-arxiv-seeds").load()
Downloading datasets/ogbn-arxiv-seeds.zip from https://data.dgl.ai/dataset/graphbolt/ogbn-arxiv-seeds.zip...
Extracting file to datasets
The dataset is already preprocessed.

The dataset consists of a graph, features, and tasks. You can get the training, validation, and test sets from the tasks. The seed nodes and the corresponding labels are already stored in each of these sets. Other metadata, such as the number of classes, is also stored in the task. In this dataset, there is only one task: node classification.

[3]:
graph = dataset.graph.to(device)
feature = dataset.feature.to(device)
train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
task_name = dataset.tasks[0].metadata["name"]
num_classes = dataset.tasks[0].metadata["num_classes"]
print(f"Task: {task_name}. Number of classes: {num_classes}")
Task: node_classification. Number of classes: 40
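
If you want a quick look at what one of these sets contains, you can inspect it directly. This is an optional sketch; it assumes the ItemSet exposes len() and integer indexing, as in recent DGL GraphBolt releases.

[ ]:
# Optional inspection of the training set (assumes ItemSet supports len() and indexing).
print(f"Number of training seeds: {len(train_set)}")
print(train_set[0])  # expected to be a (seed node ID, label) pair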

How DGL Handles Computation Dependency

The computation dependency for message passing of a single node can be described as a series of message flow graphs (MFGs).

[Figure: computation dependency of a seed node expressed as a series of message flow graphs (MFGs).]
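
For intuition, an MFG can also be built by hand on a toy graph. The sketch below is purely illustrative and uses the standard dgl.graph, dgl.in_subgraph, and dgl.to_block APIs; the sampling pipeline defined later produces such blocks automatically.

[ ]:
# A toy chain graph 0 -> 1 -> 2 -> 3 -> 4.
g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 4])))
# One layer of message passing for node 4 depends only on its in-neighbors,
# so the corresponding MFG is a small bipartite block.
block = dgl.to_block(dgl.in_subgraph(g, [4]), dst_nodes=torch.tensor([4]))
print(block.num_src_nodes(), block.num_dst_nodes(), block.num_edges())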

Defining Neighbor Sampler and Data Loader in DGL

DGL provides tools to iterate over the dataset in minibatches while generating the computation dependencies needed to compute their outputs with the MFGs above. For node classification, you can use dgl.graphbolt.DataLoader to iterate over the dataset. It accepts a datapipe that generates minibatches of nodes and their labels, samples neighbors for each node, and produces the computation dependencies in the form of MFGs. Feature fetching, block creation, and copying to the target device are also supported. All of these operations are split into separate stages of the datapipe, so that you can customize it by inserting your own operations.

Say that each node will gather messages from 4 neighbors on each layer. The code defining the data loader and neighbor sampler looks like the following.

[4]:
def create_dataloader(itemset, shuffle):
    # Sample seed nodes and their labels in minibatches.
    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)
    # Copy the minibatch to the target device.
    datapipe = datapipe.copy_to(device)
    # Sample up to 4 neighbors for each node on each of the 2 layers.
    datapipe = datapipe.sample_neighbor(graph, [4, 4])
    # Fetch the input node features.
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)
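
The fanout list passed to sample_neighbor has one entry per GNN layer. As a hedged variation on the cell above, a deeper model would simply pass a longer list; the function name and fanout values below are hypothetical and not tuned.

[ ]:
# A hypothetical 3-layer variant: one fanout entry per layer (values are illustrative).
def create_dataloader_3layer(itemset, shuffle):
    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)
    datapipe = datapipe.copy_to(device)
    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])  # 3 hops
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)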

You can iterate over the data loader, and a MiniBatch object is yielded.

[5]:
data = next(iter(create_dataloader(train_set, shuffle=True)))
print(data)
MiniBatch(seeds=tensor([ 78324,  71698, 128486,  ..., 103646, 147000,   3646]),
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([   0,    1,    3,  ..., 6095, 6099, 6099], dtype=torch.int32),
                                                                         indices=tensor([1024, 1025, 1026,  ..., 7023, 7025, 7026], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([ 78324,  71698, 128486,  ..., 161557, 165481,  67906]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([ 78324,  71698, 128486,  ...,  62400, 119968,  32427]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([   0,    1,    3,  ..., 2157, 2159, 2160], dtype=torch.int32),
                                                                         indices=tensor([1024, 1025, 1026,  ..., 3145, 3146, 3147], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([ 78324,  71698, 128486,  ...,  62400, 119968,  32427]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([ 78324,  71698, 128486,  ..., 103646, 147000,   3646]),
                            )],
          node_features={'feat': tensor([[-0.0984, -0.0103, -0.0498,  ...,  0.1173,  0.1543, -0.1799],
                                [-0.1082,  0.0946, -0.1910,  ...,  0.0996,  0.1690,  0.1150],
                                [-0.2827,  0.0679, -0.1128,  ...,  0.1666,  0.0382, -0.0793],
                                ...,
                                [-0.1164, -0.0217, -0.0913,  ...,  0.1816, -0.1474, -0.1094],
                                [-0.1101,  0.0444, -0.1484,  ...,  0.1235, -0.0812, -0.0271],
                                [-0.1057,  0.1120, -0.1044,  ...,  0.1743, -0.1070, -0.2529]])},
          labels=tensor([28,  8,  9,  ..., 16, 28,  5]),
          input_nodes=tensor([ 78324,  71698, 128486,  ..., 161557, 165481,  67906]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=None,
          blocks=[Block(num_src_nodes=7027, num_dst_nodes=3148, num_edges=6099),
                 Block(num_src_nodes=3148, num_dst_nodes=1024, num_edges=2160)],
       )

You can get the IDs of the input nodes from the MFGs.

[6]:
mfgs = data.blocks
input_nodes = mfgs[0].srcdata[dgl.NID]
print(f"Input nodes: {input_nodes}.")
Input nodes: tensor([ 78324,  71698, 128486,  ..., 161557, 165481,  67906]).
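
You can also inspect how the computation dependency shrinks layer by layer toward the seed nodes; this optional check reuses the mfgs fetched above.

[ ]:
# Print the source/destination node counts of each MFG (block).
for layer, mfg in enumerate(mfgs):
    print(f"Layer {layer}: {mfg.num_src_nodes()} source -> {mfg.num_dst_nodes()} destination nodes")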

Defining Model

Let's consider training a 2-layer GraphSAGE with neighbor sampling. The model can be written as follows:

[7]:
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        # Each MFG (block) corresponds to one layer of message passing.
        h = self.conv1(mfgs[0], x)
        h = F.relu(h)
        h = self.conv2(mfgs[1], h)
        return h


in_size = feature.size("node", None, "feat")[0]
model = Model(in_size, 64, num_classes).to(device)
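
As an optional sanity check, reusing the minibatch data fetched earlier, a single forward pass should produce one row per seed node and one logit per class.

[ ]:
# Run one forward pass on the previously fetched minibatch (no gradients needed).
with torch.no_grad():
    logits = model(data.blocks, data.node_features["feat"])
print(logits.shape)  # (number of seed nodes, num_classes)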

Defining Training Loop

The following defines the optimizer for the model created above.

[8]:
opt = torch.optim.Adam(model.parameters())

When computing the validation score for model selection, you can usually also do neighbor sampling. We can simply reuse our create_dataloader function to create two separate data loaders for training and validation.

[9]:
train_dataloader = create_dataloader(train_set, shuffle=True)
valid_dataloader = create_dataloader(valid_set, shuffle=False)

import sklearn.metrics

The following is a training loop that performs validation at the end of every epoch.

[10]:
from tqdm.auto import tqdm

for epoch in range(10):
    model.train()

    with tqdm(train_dataloader) as tq:
        for step, data in enumerate(tq):
            x = data.node_features["feat"]
            labels = data.labels

            predictions = model(data.blocks, x)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(
                labels.cpu().numpy(),
                predictions.argmax(1).detach().cpu().numpy(),
            )

            tq.set_postfix(
                {"loss": "%.03f" % loss.item(), "acc": "%.03f" % accuracy},
                refresh=False,
            )

    model.eval()

    predictions = []
    labels = []
    with tqdm(valid_dataloader) as tq, torch.no_grad():
        for data in tq:
            x = data.node_features["feat"]
            labels.append(data.labels.cpu().numpy())
            predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print("Epoch {} Validation Accuracy {}".format(epoch, accuracy))
Epoch 0 Validation Accuracy 0.4000469814423303
Epoch 1 Validation Accuracy 0.49991610456726737
Epoch 2 Validation Accuracy 0.5320648343904157
Epoch 3 Validation Accuracy 0.5523675291117152
Epoch 4 Validation Accuracy 0.5591127219034195
Epoch 5 Validation Accuracy 0.5694486392160811
Epoch 6 Validation Accuracy 0.5770327863351119
Epoch 7 Validation Accuracy 0.5784422296050203
Epoch 8 Validation Accuracy 0.5827376757609316
Epoch 9 Validation Accuracy 0.5855230041276553
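
After training, the held-out test set can be evaluated in the same way, by reusing create_dataloader on test_set. The cell below is a sketch that mirrors the validation loop above.

[ ]:
# Evaluate on the test set with the same neighbor-sampling data loader.
test_dataloader = create_dataloader(test_set, shuffle=False)
model.eval()
test_predictions, test_labels = [], []
with torch.no_grad():
    for data in test_dataloader:
        x = data.node_features["feat"]
        test_labels.append(data.labels.cpu().numpy())
        test_predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())
test_accuracy = sklearn.metrics.accuracy_score(
    np.concatenate(test_labels), np.concatenate(test_predictions)
)
print(f"Test Accuracy: {test_accuracy:.4f}")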

Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE with neighbor sampling.