6.9 数据加载并行性

在GNN的小批量训练中,我们通常需要涵盖几个阶段来生成一个小批量,包括:

  • 遍历项目集并以批量大小生成小批量种子。

  • 从图中为每个种子样本抽取负面项目。

  • 从图中为每个种子样本邻居。

  • 从采样的子图中排除种子边。

  • 获取采样子图的节点和边特征。

  • 将MiniBatches复制到目标设备。

datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_uniform_negative(g, 5)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

所有这些阶段都在单独的 IterableDataPipe 中实现,并与 PyTorch DataLoader 堆叠在一起。 这种设计使我们能够通过将不同的数据管道链接在一起来轻松自定义数据加载过程。例如,如果我们想从图中为每个种子采样负样本,我们可以简单地在 ItemSampler 之后链接 NegativeSampler

但是,仅仅将数据管道串联在一起会导致性能开销,因为不同的阶段会利用各种硬件资源,如CPU、GPU、PCIe等。因此,数据加载机制被优化以最小化这些开销并实现最佳性能。

具体来说,GraphBolt在fetch_feature之前用多进程包装数据管道,这使得多个进程可以并行运行。至于fetch_feature数据管道,我们保持其在主进程中运行,以避免进程间的数据移动开销。

此外,为了重叠数据移动和模型计算,我们在copy_to之前使用 torchdata.datapipes.iter.Perfetcher 包装数据管道,该管道从先前的数据管道预取元素并将其放入缓冲区。 这种预取对用户完全透明,不需要额外的代码。它为GNN的小批量训练带来了显著的性能提升。

请参考DataLoader的源代码以获取更多详细信息。