6.3 Training GNN for Link Prediction with Neighborhood Sampling


Define a data loader with neighbor sampling and negative sampling

You can still use the same data loader as in node/edge classification. The only difference is that you need to add an extra stage, negative sampling, before the neighbor sampling stage. The following data loader uniformly picks 5 negative destination nodes for the source node of each edge.

datapipe = datapipe.sample_uniform_negative(graph, 5)

The whole data loader pipeline is as follows:

datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_uniform_negative(graph, 5)
datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])  # 3 layers, matching the model below.
datapipe = datapipe.transform(gb.exclude_seed_edges)
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

For details about the built-in uniform negative sampler, please refer to UniformNegativeSampler.
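Conceptually, uniform negative sampling pairs the source of each seed edge with k destination node IDs drawn uniformly at random. A minimal plain-Python sketch of the idea (not the GraphBolt implementation; `num_nodes` and the edge lists are illustrative):

```python
import random

def sample_uniform_negative(pos_edges, num_nodes, k):
    """For each positive (src, dst) pair, draw k negative
    destination node IDs uniformly at random."""
    neg_edges = []
    for src, _ in pos_edges:
        for _ in range(k):
            neg_edges.append((src, random.randrange(num_nodes)))
    return neg_edges

neg = sample_uniform_negative([(0, 1), (2, 3)], num_nodes=100, k=5)
# 2 positive edges x 5 negatives each = 10 negative pairs.
assert len(neg) == 10
```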

You can also provide your own negative sampler, as long as it inherits from NegativeSampler and overrides the _sample_with_etype() method, which takes the node pairs in the minibatch and returns negative node pairs.

The following gives an example of a custom negative sampler that samples negative destination nodes according to a probability distribution proportional to a power of the node degrees.

@functional_datapipe("customized_sample_negative")
class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
    def __init__(self, datapipe, k, node_degrees):
        super().__init__(datapipe, k)
        # caches the probability distribution
        self.weights = node_degrees ** 0.75
        self.k = k

    def _sample_with_etype(self, seeds, etype=None):
        src, _ = seeds.T
        src = src.repeat_interleave(self.k)
        dst = self.weights.multinomial(len(src), replacement=True)
        return src, dst

datapipe = datapipe.customized_sample_negative(5, node_degrees)
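The 0.75 exponent (popularized by word2vec's negative sampling) flattens the degree distribution: high-degree nodes are still picked more often, but less overwhelmingly than under raw degree-proportional sampling. A quick plain-Python check of the effect (illustrative degrees only):

```python
degrees = [1, 16, 256]

# Raw degree-proportional sampling probabilities.
raw = [d / sum(degrees) for d in degrees]

# Probabilities after the degree ** 0.75 smoothing used above.
w = [d ** 0.75 for d in degrees]  # -> [1, 8, 64]
smoothed = [x / sum(w) for x in w]

# The highest-degree node's share drops from ~94% to ~88%,
# while the rarest node's share grows.
assert smoothed[2] < raw[2] and smoothed[0] > raw[0]
```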

Define a GraphSAGE model for minibatch training

class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x

When a negative sampler is provided, the data loader generates positive and negative node pairs for each minibatch, in addition to the message flow graphs (MFGs). Use compacted_seeds and labels to get the compacted node pairs and the corresponding labels.
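As a rough sketch of the resulting batch layout (an illustration of the idea, not the exact MiniBatch internals; it assumes positive pairs precede the negatives, with k negatives per positive):

```python
def link_prediction_labels(num_pos, k):
    """1.0 for each positive pair, then 0.0 for each of its k negatives."""
    return [1.0] * num_pos + [0.0] * (num_pos * k)

labels = link_prediction_labels(num_pos=1024, k=5)
assert len(labels) == 1024 * 6  # one label per candidate pair
assert sum(labels) == 1024      # only the positives are 1.0
```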

Training loop

The training loop simply involves iterating over the data loader and feeding the graphs as well as the input features to the model defined above.

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in tqdm.trange(args.epochs):
    model.train()
    total_loss = 0
    start_epoch_time = time.time()
    for step, data in enumerate(dataloader):
        # Unpack MiniBatch.
        compacted_seeds = data.compacted_seeds.T
        labels = data.labels
        node_feature = data.node_features["feat"]
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks

        # Get the embeddings of the input nodes.
        y = model(blocks, node_feature)
        logits = model.predictor(
            y[compacted_seeds[0]] * y[compacted_seeds[1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    end_epoch_time = time.time()

DGL provides an unsupervised GraphSAGE example that demonstrates link prediction on homogeneous graphs.
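The loss in the loop above scores each candidate pair with a single logit and compares it against a 1/0 label. For reference, the per-pair binary cross-entropy with logits reduces to the following numerically stable form (a plain-Python restatement, not the PyTorch implementation):

```python
import math

def bce_with_logits(z, y):
    """Binary cross-entropy on a raw logit z with target y in {0, 1},
    using the stable form max(z, 0) - z*y + log(1 + exp(-|z|))."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

# A confidently positive logit matched with label 1 incurs low loss...
assert bce_with_logits(5.0, 1.0) < 0.01
# ...while the same logit against label 0 is heavily penalized.
assert bce_with_logits(5.0, 0.0) > 5.0
```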

For heterogeneous graphs

The previous model can be easily extended to heterogeneous graphs. The only difference is that you need to use HeteroGraphConv to wrap SAGEConv according to the edge types.

class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.HeteroGraphConv({
                rel : dglnn.SAGEConv(in_size, hidden_size, "mean")
                for rel in rel_names
            }))
        self.layers.append(dglnn.HeteroGraphConv({
                rel : dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                for rel in rel_names
            }))
        self.layers.append(dglnn.HeteroGraphConv({
                rel : dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                for rel in rel_names
            }))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x

The data loader definition is very similar to that for homogeneous graphs. The only difference is that you need to give the node types for feature fetching.

datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_uniform_negative(graph, 5)
datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])  # 3 layers, matching the model below.
datapipe = datapipe.transform(gb.exclude_seed_edges)
datapipe = datapipe.fetch_feature(
    feature,
    node_feature_keys={"user": ["feat"], "item": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

If you want to give your own negative sampling function, just inherit from the NegativeSampler class and override the _sample_with_etype() method.

@functional_datapipe("customized_sample_negative")
class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
    def __init__(self, datapipe, k, node_degrees):
        super().__init__(datapipe, k)
        # caches the probability distribution
        self.weights = {
            etype: node_degrees[etype] ** 0.75 for etype in node_degrees
        }
        self.k = k

    def _sample_with_etype(self, seeds, etype):
        src, _ = seeds.T
        src = src.repeat_interleave(self.k)
        dst = self.weights[etype].multinomial(len(src), replacement=True)
        return src, dst

datapipe = datapipe.customized_sample_negative(5, node_degrees)

For heterogeneous graphs, node pairs are grouped by edge type. The training loop is almost the same as that for homogeneous graphs, except that the loss is computed on a specific edge type.

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

category = "user"
for epoch in tqdm.trange(args.epochs):
    model.train()
    total_loss = 0
    start_epoch_time = time.time()
    for step, data in enumerate(dataloader):
        # Unpack MiniBatch.
        compacted_seeds = data.compacted_seeds
        labels = data.labels
        node_features = {
            ntype: data.node_features[(ntype, "feat")]
            for ntype in data.blocks[0].srctypes
        }
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks
        # Get the embeddings of the input nodes.
        y = model(blocks, node_features)
        logits = model.predictor(
            y[category][compacted_seeds[category][:, 0]]
            * y[category][compacted_seeds[category][:, 1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels[category])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    end_epoch_time = time.time()
© Copyright 2018, DGL Team. Revision 2ee440a6.