GraphDataLoader

class dgl.dataloading.GraphDataLoader(dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs)[source]

Bases: DataLoader

Batched graph data loader.

PyTorch dataloader for batch-iterating over a set of graphs, generating the batched graph and the corresponding label tensor (if provided) for each minibatch.

Parameters:
  • dataset (torch.utils.data.Dataset) – The dataset to load graphs from.

  • collate_fn (Function, default is None) – The customized collate function. The default collate function is used if none is given; a sketch of a custom one appears after this parameter list.

  • use_ddp (boolean, optional) –

    If True, tells the DataLoader to split the training set for each participating process appropriately using torch.utils.data.distributed.DistributedSampler.

    Overrides the sampler argument of torch.utils.data.DataLoader.

  • ddp_seed (int, optional) –

    The seed for shuffling the dataset in torch.utils.data.distributed.DistributedSampler.

    Only effective when use_ddp is True.

  • kwargs (dict) –

    Keyword arguments to be passed to the parent PyTorch torch.utils.data.DataLoader class. Common arguments are:

    • batch_size (int): The number of indices in each batch.

    • drop_last (bool): Whether to drop the last incomplete batch.

    • shuffle (bool): Whether to randomly shuffle the indices at each epoch.
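
As a sketch of the collate_fn hook, assuming each dataset item is a (graph, label) pair with a scalar label (the function name and the pair layout are illustrative, not fixed by the API):

>>> import dgl
>>> import torch
>>> def collate(samples):
...     # Each sample is assumed to be a (graph, label) pair; merge the
...     # graphs into one batched graph and stack the labels.
...     graphs, labels = map(list, zip(*samples))
...     return dgl.batch(graphs), torch.tensor(labels)
>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, collate_fn=collate, batch_size=1024, shuffle=True)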

Examples

To train a GNN for graph classification on a set of graphs in dataset:

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
...     train_on(batched_graph, labels)
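
The train_on call above is a placeholder. A minimal sketch of what such a step might look like for graph classification, assuming a model and optimizer defined elsewhere and node features stored under the key 'feat' (all hypothetical names, not part of the GraphDataLoader API):

>>> import torch.nn.functional as F
>>> def train_on(batched_graph, labels):
...     # Forward pass on the batched graph; model is assumed to return
...     # one logit vector per graph in the minibatch.
...     logits = model(batched_graph, batched_graph.ndata['feat'])
...     loss = F.cross_entropy(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()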

Using Distributed Data Parallel

If you are using PyTorch's distributed training (e.g., when using torch.nn.parallel.DistributedDataParallel), you can train the model by turning on the use_ddp option:

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
...     dataloader.set_epoch(epoch)
...     for batched_graph, labels in dataloader:
...         train_on(batched_graph, labels)
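
For context, a minimal sketch of the DDP boilerplate that would surround this loop; the backend choice and the MyGNN model class are assumptions for illustration, not requirements of GraphDataLoader:

>>> import torch.distributed as dist
>>> from torch.nn.parallel import DistributedDataParallel
>>> dist.init_process_group('gloo')  # assumed backend; 'nccl' is typical on GPUs
>>> model = DistributedDataParallel(MyGNN())  # MyGNN is a hypothetical model class

Calling set_epoch at the start of each epoch is what lets the underlying torch.utils.data.distributed.DistributedSampler produce a different shuffle per epoch.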