Link Prediction


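This tutorial shows how to train a multi-layer GraphSAGE for link prediction on CoraGraphDataset. The dataset contains 2,708 nodes and 10,556 edges.

By the end of this tutorial, you will be able to

  • Train a GNN model for link prediction on a target device with DGL's neighbor sampling components.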

Install DGL package

[1]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Install the CPU version by default. If you want to install the CUDA version,
# please refer to https://www.dgl.ai/pages/start.html and change the runtime type
# accordingly.
device = torch.device("cpu")
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html

try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
DGL installed!

Loading Dataset

The Cora dataset is already available in GraphBolt as a BuiltinDataset.

[2]:
dataset = gb.BuiltinDataset("cora-seeds").load()
Downloading datasets/cora-seeds.zip from https://data.dgl.ai/dataset/graphbolt/cora-seeds.zip...
Extracting file to datasets
Start to preprocess the on-disk dataset.
Finish preprocessing the on-disk dataset.

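A dataset consists of a graph, features, and tasks. You can get the training, validation, and test sets from a task. The seed nodes and their corresponding labels are already stored in each of these sets. This dataset contains 2 tasks, one for node classification and one for link prediction. We will use the link prediction task.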

[3]:
graph = dataset.graph.to(device)
feature = dataset.feature.to(device)
train_set = dataset.tasks[1].train_set
test_set = dataset.tasks[1].test_set
task_name = dataset.tasks[1].metadata["name"]
print(f"Task: {task_name}.")
Task: link_prediction.

Defining Neighbor Sampler and Data Loader in DGL

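Unlike the full-graph link prediction tutorial, the common practice for training GNNs on large graphs is to iterate over the edges in mini-batches, because computing the probability of every edge at once is usually infeasible. For each mini-batch of edges, you compute the output representations of their incident nodes with neighbor sampling and a GNN, in the same way as introduced in the node classification tutorial.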
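To make the idea of neighbor sampling concrete, here is a minimal, dependency-free sketch of layer-wise sampling on an adjacency-list graph. The function name `sample_neighbors`, the `adj` dictionary, and the toy graph are illustrative assumptions, not GraphBolt APIs, and the layer ordering is a simplification of what GraphBolt returns.

```python
import random


def sample_neighbors(adj, seeds, fanouts, rng):
    """Layer-wise neighbor sampling on an adjacency-list graph (illustrative).

    `adj` maps each node to the list of its in-neighbors. For every layer,
    each frontier node keeps at most `fanout` neighbors; a fanout of -1
    keeps all of them.
    """
    frontier = list(dict.fromkeys(seeds))  # de-duplicate, keep order
    layers = []
    for fanout in fanouts:
        edges = []
        for dst in frontier:
            neighbors = adj.get(dst, [])
            if fanout != -1 and len(neighbors) > fanout:
                neighbors = rng.sample(neighbors, fanout)
            edges.extend((src, dst) for src in neighbors)
        layers.append(edges)
        # The next layer expands from the current frontier plus the new sources.
        frontier = list(dict.fromkeys(frontier + [src for src, _ in edges]))
    return layers, frontier


# Hypothetical toy graph: node -> in-neighbors.
adj = {0: [1, 2, 3], 1: [0], 2: [0, 3], 3: [0, 2]}
layers, input_nodes = sample_neighbors(
    adj, seeds=[0, 1], fanouts=[2, 2], rng=random.Random(0)
)
```

`input_nodes` lists every node whose features are needed to compute the seed representations, which is why only a small part of a large graph has to be touched per mini-batch.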

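To perform link prediction, you need to specify a negative sampler. DGL provides built-in negative samplers such as dgl.graphbolt.UniformNegativeSampler. In this tutorial, 5 negative examples are drawn uniformly for each positive example.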

Apart from the negative sampler, the rest of the code is identical to the node classification tutorial.
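As a rough illustration of uniform negative sampling, the sketch below appends a fixed number of negatives to each positive edge by keeping the source node and drawing a random destination. The helper `add_uniform_negatives` and its return layout are assumptions made for illustration; this is not how `UniformNegativeSampler` is implemented.

```python
import random


def add_uniform_negatives(pos_edges, num_nodes, negative_ratio, rng):
    """Return (seeds, labels, indexes) for a batch of positive edges.

    Positives come first, followed by `negative_ratio` negatives per positive.
    Each negative keeps the source node and draws a destination uniformly at
    random, so it may occasionally coincide with a real edge.
    """
    seeds = [tuple(edge) for edge in pos_edges]
    labels = [1.0] * len(pos_edges)
    indexes = list(range(len(pos_edges)))
    for i, (src, _dst) in enumerate(pos_edges):
        for _ in range(negative_ratio):
            seeds.append((src, rng.randrange(num_nodes)))
            labels.append(0.0)
            indexes.append(i)  # which positive this negative belongs to
    return seeds, labels, indexes


seeds, labels, indexes = add_uniform_negatives(
    [(0, 1), (2, 3)], num_nodes=10, negative_ratio=5, rng=random.Random(0)
)
```

With 2 positives and a ratio of 5, the batch grows to 12 node pairs: 2 labeled 1 and 10 labeled 0.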

[4]:
from functools import partial
datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)
datapipe = datapipe.copy_to(device)
datapipe = datapipe.sample_uniform_negative(graph, 5)
datapipe = datapipe.sample_neighbor(graph, [5, 5])
datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
train_dataloader = gb.DataLoader(datapipe)
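The `exclude_seed_edges` step in the pipeline above removes the edges being predicted from the sampled neighborhoods, so the model cannot simply read the answer off the graph. The following is a minimal sketch of that idea on plain edge lists; `exclude_edges` is a hypothetical helper, not the GraphBolt function.

```python
def exclude_edges(sampled_edges, seed_edges, include_reverse_edges=True):
    """Drop sampled edges that coincide with seed edges (and their reverses)."""
    banned = set(map(tuple, seed_edges))
    if include_reverse_edges:
        banned |= {(dst, src) for src, dst in banned}
    return [edge for edge in sampled_edges if tuple(edge) not in banned]


kept = exclude_edges([(0, 1), (1, 0), (2, 1)], seed_edges=[(0, 1)])
```

Here `kept` retains only `(2, 1)`, because both the seed edge `(0, 1)` and its reverse `(1, 0)` are filtered out.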

You can peek at one mini-batch from train_dataloader to see what it gives you.

[5]:
data = next(iter(train_dataloader))
print(f"MiniBatch: {data}")
MiniBatch: MiniBatch(seeds=tensor([[ 359, 1620],
                        [ 195,   94],
                        [ 382,  191],
                        ...,
                        [ 897, 2143],
                        [ 897,  480],
                        [ 897,  335]], dtype=torch.int32),
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([   0,    2,    6,  ..., 6884, 6886, 6891], dtype=torch.int32),
                                                                         indices=tensor([1274,   88,  928,  ..., 1271, 2511, 2512], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([ 359, 1620,  195,  ..., 1491, 2099, 2426], dtype=torch.int32),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([ 359, 1620,  195,  ..., 1079, 1270, 2425], dtype=torch.int32),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([   0,    2,    6,  ..., 3689, 3692, 3694], dtype=torch.int32),
                                                                         indices=tensor([1274,   88,   88,  ...,   62,  251,  320], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([ 359, 1620,  195,  ..., 1079, 1270, 2425], dtype=torch.int32),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([ 359, 1620,  195,  ..., 1622,  752,  335], dtype=torch.int32),
                            )],
          node_features={'feat': tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
                                [0.0000, 0.0000, 0.0526,  ..., 0.0000, 0.0000, 0.0000],
                                [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
                                ...,
                                [0.0000, 0.0476, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
                                [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
                                [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])},
          labels=tensor([1., 1., 1.,  ..., 0., 0., 0.]),
          input_nodes=tensor([ 359, 1620,  195,  ..., 1491, 2099, 2426], dtype=torch.int32),
          indexes=tensor([  0,   1,   2,  ..., 255, 255, 255]),
          edge_features=[{},
                        {}],
          compacted_seeds=tensor([[   0,    1],
                                  [   2,    3],
                                  [   4,    5],
                                  ...,
                                  [  23,  924],
                                  [  23,  950],
                                  [  23, 1273]], dtype=torch.int32),
          blocks=[Block(num_src_nodes=2513, num_dst_nodes=2218, num_edges=6891),
                 Block(num_src_nodes=2218, num_dst_nodes=1274, num_edges=3694)],
       )
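The `seeds` and `compacted_seeds` fields above hold the same node pairs under two numberings: original graph IDs and batch-local IDs. A minimal sketch of such a relabeling, assuming local IDs are assigned in order of first appearance (the first rows of the printed batch are consistent with this, but the exact ordering is a GraphBolt implementation detail), could look like this. `compact_seeds` is a hypothetical helper.

```python
def compact_seeds(seeds):
    """Relabel node IDs in `seeds` to 0..n-1 in order of first appearance."""
    local_id = {}
    compacted = []
    for src, dst in seeds:
        pair = []
        for node in (src, dst):
            if node not in local_id:
                local_id[node] = len(local_id)
            pair.append(local_id[node])
        compacted.append(tuple(pair))
    # Position i holds the original ID of local node i.
    original_ids = list(local_id)
    return compacted, original_ids


compacted, original_ids = compact_seeds(
    [(359, 1620), (195, 94), (382, 191), (359, 94)]
)
```

Local IDs index directly into the rows of the model output, which is why the training loop uses `compacted_seeds` rather than `seeds` to look up embeddings.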

Defining Model for Node Representation

Let's consider training a 2-layer GraphSAGE with neighbor sampling. The model can be written as follows:

[6]:
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F


class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x

Defining Training Loop

The following initializes the model and defines the optimizer.

[7]:
in_size = feature.size("node", None, "feat")[0]
model = SAGE(in_size, 128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

The following is the training loop for link prediction.

[8]:
from tqdm.auto import tqdm
for epoch in range(3):
    model.train()
    total_loss = 0
    for step, data in tqdm(enumerate(train_dataloader)):
        # Get node pairs with labels for loss calculation.
        compacted_seeds = data.compacted_seeds.T
        labels = data.labels
        node_feature = data.node_features["feat"]
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks

        # Get the embeddings of the input nodes.
        y = model(blocks, node_feature)
        logits = model.predictor(
            y[compacted_seeds[0]] * y[compacted_seeds[1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch:03d} | Loss {total_loss / (step + 1):.3f}")
Epoch 000 | Loss 0.559
Epoch 001 | Loss 0.449
Epoch 002 | Loss 0.445
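The training loop scores each node pair from the elementwise product of the two node embeddings and trains with binary cross-entropy on logits. The sketch below shows those two pieces in plain Python; the predictor MLP is omitted, and `pair_features` and `bce_with_logits` are illustrative names rather than library functions.

```python
import math


def pair_features(h_src, h_dst):
    """Elementwise (Hadamard) product of two node embeddings."""
    return [a * b for a, b in zip(h_src, h_dst)]


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy computed from raw logits, numerically stable."""
    total = 0.0
    for x, y in zip(logits, labels):
        # Equivalent to -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))],
        # rearranged so that exp() never overflows.
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


features = pair_features([1.0, 2.0], [3.0, 4.0])
loss_at_zero = bce_with_logits([0.0], [1.0])
```

A logit of 0 corresponds to a predicted probability of 0.5, so `loss_at_zero` equals log 2. Confident, correct predictions drive the loss toward 0.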

Evaluating Performance with Link Prediction

[9]:
model.eval()

datapipe = gb.ItemSampler(test_set, batch_size=256, shuffle=False)
datapipe = datapipe.copy_to(device)
# Since we need to use all neighborhoods for evaluation, we set the fanout
# to -1.
datapipe = datapipe.sample_neighbor(graph, [-1, -1])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
eval_dataloader = gb.DataLoader(datapipe, num_workers=0)

logits = []
labels = []
for step, data in tqdm(enumerate(eval_dataloader)):
    # Get node pairs with labels for evaluation.
    compacted_seeds = data.compacted_seeds.T
    label = data.labels

    # The features of sampled nodes.
    x = data.node_features["feat"]

    # Forward.
    y = model(data.blocks, x)
    logit = (
        model.predictor(y[compacted_seeds[0]] * y[compacted_seeds[1]])
        .squeeze()
        .detach()
    )

    logits.append(logit)
    labels.append(label)

logits = torch.cat(logits, dim=0)
labels = torch.cat(labels, dim=0)


# Compute the AUROC score.
from sklearn.metrics import roc_auc_score

auc = roc_auc_score(labels.cpu(), logits.cpu())
print("Link Prediction AUC:", auc)
Link Prediction AUC: 0.6802318007232542
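The AUROC reported above can be read as the probability that a randomly chosen positive pair receives a higher score than a randomly chosen negative pair. A minimal sketch of that definition, using a quadratic-time pairwise comparison that is only suitable for small inputs, is shown below. `roc_auc` is an illustrative helper, not the scikit-learn implementation.

```python
def roc_auc(labels, scores):
    """AUROC as the probability that a random positive outranks a random negative.

    Ties between a positive and a negative count as half a win.
    """
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


perfect = roc_auc([1, 1, 0, 0], [0.9, 0.8, 0.3, 0.1])
mixed = roc_auc([1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1])
```

A score of 0.5 corresponds to random ranking and 1.0 to a perfect ranking, so the value printed above indicates a model that is better than chance after only 3 epochs.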

Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE for link prediction with neighbor sampling.


© Copyright 2018, DGL Team. Revision 2ee440a6.
