GraphDataLoader
- class dgl.dataloading.GraphDataLoader(dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs)[source]
Bases: DataLoader
Batched graph data loader.
PyTorch dataloader for batch-iterating over a set of graphs, generating the batched graph and the corresponding label tensor (if provided) of each mini-batch.
- Parameters:
dataset (torch.utils.data.Dataset) – The dataset to load graphs from.
collate_fn (Function, default None) – The customized collate function. Will use the default collate function if not given.
use_ddp (boolean, optional) – If True, tells the DataLoader to split the training set for each participating process appropriately using torch.utils.data.distributed.DistributedSampler. Overrides the sampler argument of torch.utils.data.DataLoader.
ddp_seed (int, optional) – The seed for shuffling the dataset in torch.utils.data.distributed.DistributedSampler. Only effective when use_ddp is True.
kwargs (dict) – Key-word arguments to be passed to the parent PyTorch torch.utils.data.DataLoader class. Common arguments are:
  batch_size (int): The number of indices in each batch.
  drop_last (bool): Whether to drop the last incomplete batch.
  shuffle (bool): Whether to randomly shuffle the indices at each epoch.
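For reference, a custom collate_fn might look like the following sketch, which reproduces the default behavior: merging the graphs of a mini-batch with dgl.batch and collecting the labels into a tensor. The (graph, label) sample layout and the scalar integer labels are assumptions about the dataset:

>>> import dgl
>>> import torch
>>> def collate(samples):
...     # Each sample is assumed to be a (graph, label) pair.
...     graphs, labels = map(list, zip(*samples))
...     # Merge the graphs of the mini-batch into a single batched graph.
...     return dgl.batch(graphs), torch.tensor(labels)
>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, collate_fn=collate, batch_size=32)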
Examples
To train a GNN for graph classification on a set of graphs in dataset:

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
...     train_on(batched_graph, labels)
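The train_on step above is a placeholder. A minimal sketch of what it could contain is shown below; the Classifier module, the 'feat' node-feature key, and the sizes (in_feats=10, n_classes=2) are all illustrative assumptions, not part of the GraphDataLoader API:

>>> import dgl
>>> import torch
>>> import torch.nn.functional as F
>>> from dgl.nn import GraphConv
>>> class Classifier(torch.nn.Module):
...     def __init__(self, in_feats, hidden_size, n_classes):
...         super().__init__()
...         self.conv = GraphConv(in_feats, hidden_size)
...         self.classify = torch.nn.Linear(hidden_size, n_classes)
...     def forward(self, g, feat):
...         h = F.relu(self.conv(g, feat))
...         g.ndata['h'] = h
...         # Mean readout pools node features per graph in the batch.
...         return self.classify(dgl.mean_nodes(g, 'h'))
>>> model = Classifier(in_feats=10, hidden_size=64, n_classes=2)
>>> opt = torch.optim.Adam(model.parameters())
>>> def train_on(batched_graph, labels):
...     # One optimization step over the whole batched graph.
...     logits = model(batched_graph, batched_graph.ndata['feat'])
...     loss = F.cross_entropy(logits, labels)
...     opt.zero_grad()
...     loss.backward()
...     opt.step()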
Using Distributed Data Parallel
If you are using PyTorch's distributed training (e.g. when using torch.nn.parallel.DistributedDataParallel), you can train the model by turning on the use_ddp option:

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
...     dataloader.set_epoch(epoch)
...     for batched_graph, labels in dataloader:
...         train_on(batched_graph, labels)
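Note that use_ddp assumes the usual PyTorch distributed setup has already been done in each worker process before the dataloader is created. A minimal sketch is below; the 'gloo' backend and the env:// rendezvous are assumptions, so adjust them for your cluster:

>>> import torch.distributed as dist
>>> from torch.nn.parallel import DistributedDataParallel
>>> # Assumed backend: 'gloo' for CPU training; use 'nccl' for GPU training.
>>> dist.init_process_group(backend='gloo', init_method='env://')
>>> model = DistributedDataParallel(model)  # wrap the model before the training loop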