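Node Classification
This tutorial shows how to train a multi-layer GraphSAGE for node classification on the ogbn-arxiv dataset provided by the Open Graph Benchmark (OGB). The dataset contains about 170 thousand nodes and over 1 million edges.
By the end of this tutorial, you will be able to train a GNN model for node classification on a single GPU with DGL's neighbor sampling components.
Install DGL package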
[1]:
# Install required packages.
import os
import torch
import numpy as np
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"
# Install the CPU version by default. If you want to install the CUDA version,
# please refer to https://www.dgl.ai/pages/start.html and change the runtime
# type accordingly.
device = torch.device("cpu")
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html
try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
DGL installed!
Loading Dataset
ogbn-arxiv is already prepared as a BuiltinDataset in GraphBolt.
[2]:
dataset = gb.BuiltinDataset("ogbn-arxiv-seeds").load()
Downloading datasets/ogbn-arxiv-seeds.zip from https://data.dgl.ai/dataset/graphbolt/ogbn-arxiv-seeds.zip...
Extracting file to datasets
The dataset is already preprocessed.
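A dataset consists of a graph, features and tasks. You can get the training-validation-test sets from the tasks. The seed nodes and their corresponding labels are already stored in each training-validation-test set. Other metadata, such as the number of classes, is also stored in the tasks. This dataset has only one task: node classification.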
[3]:
graph = dataset.graph.to(device)
feature = dataset.feature.to(device)
train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
task_name = dataset.tasks[0].metadata["name"]
num_classes = dataset.tasks[0].metadata["num_classes"]
print(f"Task: {task_name}. Number of classes: {num_classes}")
Task: node_classification. Number of classes: 40
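How DGL Handles Computation Dependency
The computation dependency for message passing of a single node can be described as a series of message flow graphs (MFGs).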
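To make the idea concrete, the following is a minimal sketch in plain Python (no DGL required) of how the dependencies of one seed node unfold layer by layer. The toy graph and the helper name build_mfgs are illustrative assumptions, not DGL APIs. No sampling is applied here, so every in-neighbor is kept.

```python
# Toy directed graph: in_neighbors[v] lists the nodes that send messages to v.
# The graph and the helper name `build_mfgs` are illustrative assumptions.
in_neighbors = {
    0: [1, 2],
    1: [3],
    2: [3, 4],
    3: [],
    4: [0],
}


def build_mfgs(seeds, num_layers):
    """Return one (src_nodes, dst_nodes, edges) triple per layer, input layer first."""
    mfgs = []
    dst = list(seeds)
    for _ in range(num_layers):
        # Destination nodes are listed first among the sources, because a node
        # also needs its own representation from the previous layer.
        src = list(dst)
        edges = []
        for v in dst:
            for u in in_neighbors[v]:
                if u not in src:
                    src.append(u)
                edges.append((u, v))
        mfgs.insert(0, (src, dst, edges))
        # The earlier layer must produce outputs for every source node.
        dst = src
    return mfgs


mfgs = build_mfgs(seeds=[0], num_layers=2)
for layer, (src, dst, edges) in enumerate(mfgs):
    print(f"layer {layer}: src={src} dst={dst} edges={edges}")
```

Note that the destination nodes of the first MFG are exactly the source nodes of the second one, which is what allows the layers of a GNN to be chained.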
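Defining Neighbor Sampler and Data Loader in DGL
DGL provides tools to iterate over the dataset in minibatches while generating the computation dependencies needed to compute their outputs with the MFGs above. For node classification, you can use dgl.graphbolt.DataLoader to iterate over the dataset. It accepts a data pipe that generates minibatches of nodes and their labels, samples neighbors for each node, and generates the computation dependencies in the form of MFGs. Feature fetching, block creation and copying to the target device are also supported. All these operations are split into separate stages in the data pipe, so that you can customize the data pipeline by inserting your own operations.
Let's say that each node will gather messages from 4 neighbors on each layer. The code defining the data loader and neighbor sampler will look like the following.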
[4]:
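def create_dataloader(itemset, shuffle):
    # Split the seed nodes into minibatches of 1024.
    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)
    # Move each minibatch to the target device.
    datapipe = datapipe.copy_to(device)
    # Sample 4 neighbors per node for each of the 2 layers.
    datapipe = datapipe.sample_neighbor(graph, [4, 4])
    # Fetch the "feat" features of the input nodes.
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)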
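The effect of a per-layer fanout can be illustrated without DGL. The following is a rough sketch in plain Python, assuming a small synthetic graph; the helper sample_layer is a hypothetical name and does not reproduce GraphBolt's actual sampler, only the idea of keeping at most 4 in-neighbors per node per layer.

```python
import random

# Synthetic in-neighbor lists (illustrative assumption, not ogbn-arxiv).
in_neighbors = {
    v: [u for u in range(20) if u != v and (u + v) % 3 == 0] for v in range(20)
}


def sample_layer(dst_nodes, fanout, rng):
    """Sample at most `fanout` in-neighbors for every destination node."""
    edges = []
    for v in dst_nodes:
        candidates = in_neighbors[v]
        picked = rng.sample(candidates, min(fanout, len(candidates)))
        edges.extend((u, v) for u in picked)
    return edges


rng = random.Random(0)
seeds = [0, 1, 2]
fanouts = [4, 4]

frontier = list(seeds)
layers = []  # sampled edges per layer, input layer first
for fanout in fanouts:
    edges = sample_layer(frontier, fanout, rng)
    layers.insert(0, edges)
    # Keep the current nodes first, then append the newly reached sources.
    frontier = frontier + sorted({u for u, _ in edges} - set(frontier))

for i, edges in enumerate(layers):
    print(f"layer {i}: {len(edges)} sampled edges")
print(f"{len(frontier)} input nodes are needed for {len(seeds)} seeds")
```

Because the fanout bounds the number of edges per destination node, the size of each layer grows at most by a factor of the fanout plus one, instead of with the full neighborhood size.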
You can iterate over the data loader, which yields MiniBatch objects.
[5]:
data = next(iter(create_dataloader(train_set, shuffle=True)))
print(data)
MiniBatch(seeds=tensor([ 78324, 71698, 128486, ..., 103646, 147000, 3646]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([ 0, 1, 3, ..., 6095, 6099, 6099], dtype=torch.int32),
indices=tensor([1024, 1025, 1026, ..., 7023, 7025, 7026], dtype=torch.int32),
),
original_row_node_ids=tensor([ 78324, 71698, 128486, ..., 161557, 165481, 67906]),
original_edge_ids=None,
original_column_node_ids=tensor([ 78324, 71698, 128486, ..., 62400, 119968, 32427]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([ 0, 1, 3, ..., 2157, 2159, 2160], dtype=torch.int32),
indices=tensor([1024, 1025, 1026, ..., 3145, 3146, 3147], dtype=torch.int32),
),
original_row_node_ids=tensor([ 78324, 71698, 128486, ..., 62400, 119968, 32427]),
original_edge_ids=None,
original_column_node_ids=tensor([ 78324, 71698, 128486, ..., 103646, 147000, 3646]),
)],
node_features={'feat': tensor([[-0.0984, -0.0103, -0.0498, ..., 0.1173, 0.1543, -0.1799],
[-0.1082, 0.0946, -0.1910, ..., 0.0996, 0.1690, 0.1150],
[-0.2827, 0.0679, -0.1128, ..., 0.1666, 0.0382, -0.0793],
...,
[-0.1164, -0.0217, -0.0913, ..., 0.1816, -0.1474, -0.1094],
[-0.1101, 0.0444, -0.1484, ..., 0.1235, -0.0812, -0.0271],
[-0.1057, 0.1120, -0.1044, ..., 0.1743, -0.1070, -0.2529]])},
labels=tensor([28, 8, 9, ..., 16, 28, 5]),
input_nodes=tensor([ 78324, 71698, 128486, ..., 161557, 165481, 67906]),
indexes=None,
edge_features=[{},
{}],
compacted_seeds=None,
blocks=[Block(num_src_nodes=7027, num_dst_nodes=3148, num_edges=6099),
Block(num_src_nodes=3148, num_dst_nodes=1024, num_edges=2160)],
)
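The sampled_csc fields printed above store each sampled subgraph in compressed sparse column (CSC) form. As a minimal sketch with hand-made values (not taken from the output above), and assuming columns correspond to destination nodes as the sizes in the output suggest, the two arrays can be read like this:

```python
# Small hand-made CSC structure (illustrative values only).
indptr = [0, 1, 3, 3]  # one entry per destination node, plus a final end marker
indices = [3, 4, 5]    # compacted source-node IDs, grouped by destination


def in_edges(indptr, indices):
    """Expand a CSC structure into a list of (source, destination) pairs."""
    edges = []
    for dst in range(len(indptr) - 1):
        # The sources of `dst` occupy indices[indptr[dst]:indptr[dst + 1]].
        for pos in range(indptr[dst], indptr[dst + 1]):
            edges.append((indices[pos], dst))
    return edges


edges = in_edges(indptr, indices)
print(edges)  # [(3, 0), (4, 1), (5, 1)]
```

The last entry of indptr equals the number of edges, which matches the output above: the first subgraph's indptr ends at 6099 and its block reports num_edges=6099. Two equal consecutive entries mean that the corresponding destination node has no sampled neighbor.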
You can get the input node IDs from the MFGs.
[6]:
mfgs = data.blocks
input_nodes = mfgs[0].srcdata[dgl.NID]
print(f"Input nodes: {input_nodes}.")
Input nodes: tensor([ 78324, 71698, 128486, ..., 161557, 165481, 67906]).
Defining Model
Let's consider training a 2-layer GraphSAGE with neighbor sampling. The model can be written as follows:
[7]:
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        # mfgs[0] is the MFG of the first layer, mfgs[1] that of the second.
        h = self.conv1(mfgs[0], x)
        h = F.relu(h)
        h = self.conv2(mfgs[1], h)
        return h

# Size of the input node features.
in_size = feature.size("node", None, "feat")[0]
model = Model(in_size, 64, num_classes).to(device)
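For intuition about what a mean-aggregating GraphSAGE layer computes on one MFG, here is a simplified sketch in plain Python. It assumes scalar node features and two scalar weights in place of the learned weight matrices and bias that SAGEConv actually uses; the function name sage_mean_layer is hypothetical.

```python
def sage_mean_layer(src_feats, num_dst, edges, w_self, w_neigh):
    """One GraphSAGE-style mean layer on a bipartite block with scalar features.

    src_feats: one feature per source node; the first `num_dst` source nodes
        are the destination nodes themselves.
    edges: (src, dst) pairs in block-local IDs.
    """
    out = []
    for v in range(num_dst):
        neigh = [src_feats[u] for u, d in edges if d == v]
        # In this sketch, nodes without sampled neighbors get a zero neighbor term.
        mean = sum(neigh) / len(neigh) if neigh else 0.0
        out.append(w_self * src_feats[v] + w_neigh * mean)
    return out


src_feats = [1.0, 2.0, 3.0, 5.0]
edges = [(2, 0), (3, 0), (3, 1)]
h = sage_mean_layer(src_feats, num_dst=2, edges=edges, w_self=1.0, w_neigh=0.5)
print(h)  # [3.0, 4.5]
```

The output has one entry per destination node, which is why the output of the first layer can be passed directly as the input features of the second MFG.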
Defining Training Loop
The model has been initialized above. The following defines the optimizer.
[8]:
opt = torch.optim.Adam(model.parameters())
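When computing the validation score for model selection, you can usually also do neighbor sampling. We can simply reuse our create_dataloader function to create two separate data loaders for training and validation.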
[9]:
train_dataloader = create_dataloader(train_set, shuffle=True)
valid_dataloader = create_dataloader(valid_set, shuffle=False)
import sklearn.metrics
The following is a training loop that performs validation at the end of every epoch and prints the validation accuracy.
[10]:
from tqdm.auto import tqdm

for epoch in range(10):
    model.train()
    with tqdm(train_dataloader) as tq:
        for step, data in enumerate(tq):
            x = data.node_features["feat"]
            labels = data.labels
            predictions = model(data.blocks, x)
            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            accuracy = sklearn.metrics.accuracy_score(
                labels.cpu().numpy(),
                predictions.argmax(1).detach().cpu().numpy(),
            )
            tq.set_postfix(
                {"loss": "%.03f" % loss.item(), "acc": "%.03f" % accuracy},
                refresh=False,
            )
    model.eval()
    predictions = []
    labels = []
    with tqdm(valid_dataloader) as tq, torch.no_grad():
        for data in tq:
            x = data.node_features["feat"]
            labels.append(data.labels.cpu().numpy())
            predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print("Epoch {} Validation Accuracy {}".format(epoch, accuracy))
Epoch 0 Validation Accuracy 0.4000469814423303
Epoch 1 Validation Accuracy 0.49991610456726737
Epoch 2 Validation Accuracy 0.5320648343904157
Epoch 3 Validation Accuracy 0.5523675291117152
Epoch 4 Validation Accuracy 0.5591127219034195
Epoch 5 Validation Accuracy 0.5694486392160811
Epoch 6 Validation Accuracy 0.5770327863351119
Epoch 7 Validation Accuracy 0.5784422296050203
Epoch 8 Validation Accuracy 0.5827376757609316
Epoch 9 Validation Accuracy 0.5855230041276553
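If you want to keep the model with the best validation accuracy, one possible approach is to track the best score and save only when it improves. The sketch below is an assumption about how this could be done, not part of the loop above; the class name BestCheckpoint is hypothetical and the accuracies are made-up illustrative values.

```python
class BestCheckpoint:
    """Remember the best validation score and call `save_fn` when it improves."""

    def __init__(self, save_fn):
        self.save_fn = save_fn
        self.best_score = float("-inf")
        self.best_epoch = None

    def update(self, epoch, score):
        """Return True if `score` is a new best (and the model was saved)."""
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            self.save_fn()
            return True
        return False


# In the training loop, save_fn could be something like
# lambda: torch.save(model.state_dict(), "model.pt")  (the path is hypothetical).
saves = []
tracker = BestCheckpoint(save_fn=lambda: saves.append("saved"))
for epoch, acc in enumerate([0.40, 0.50, 0.48, 0.55]):
    tracker.update(epoch, acc)
print(tracker.best_epoch, tracker.best_score, len(saves))  # 3 0.55 3
```

Calling tracker.update(epoch, accuracy) right after the validation accuracy is computed would be enough to integrate it into the loop.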
Conclusion
In this tutorial, you have learned how to train a multi-layer GraphSAGE with neighbor sampling.