torch_geometric.data.lightning.LightningDataset

class LightningDataset(train_dataset: Dataset, val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, pred_dataset: Optional[Dataset] = None, **kwargs: Any)

Bases: LightningDataModule

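Converts a set of Dataset objects (one per split: training, validation, test and prediction) into a pytorch_lightning.LightningDataModule variant. The result can then be used directly as a datamodule for multi-GPU graph-level training with PyTorch Lightning. LightningDataset takes care of providing mini-batches via torch_geometric.loader.DataLoader.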
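To make the idea of "one dataset per split, served as mini-batches" concrete, here is a simplified, hypothetical sketch in plain Python. It is not PyG's implementation: the class name `SketchDataModule`, its `seed` argument and its list-based batching are illustrative assumptions. It also assumes that only the training split is shuffled and that a split which was not provided yields no loader (`None`).

```python
import random
from typing import Any, Iterator, List, Optional, Sequence


class SketchDataModule:
    """Hypothetical stand-in for a datamodule: holds up to four dataset
    splits and hands out mini-batch iterators (not PyG's implementation)."""

    def __init__(self, train_dataset: Sequence[Any],
                 val_dataset: Optional[Sequence[Any]] = None,
                 test_dataset: Optional[Sequence[Any]] = None,
                 pred_dataset: Optional[Sequence[Any]] = None,
                 batch_size: int = 1, seed: int = 0):
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.pred_dataset = pred_dataset
        self.batch_size = batch_size
        self._rng = random.Random(seed)  # assumption: seeded for repeatability

    def _batches(self, dataset: Sequence[Any],
                 shuffle: bool) -> Iterator[List[Any]]:
        # Visit samples in (optionally shuffled) order and group them into
        # lists of at most `batch_size` items; the last batch may be smaller.
        order = list(range(len(dataset)))
        if shuffle:
            self._rng.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            yield [dataset[i] for i in order[start:start + self.batch_size]]

    def _maybe(self, dataset: Optional[Sequence[Any]],
               ) -> Optional[Iterator[List[Any]]]:
        # Assumption: a split that was not provided yields no loader at all.
        return None if dataset is None else self._batches(dataset, False)

    def train_dataloader(self) -> Iterator[List[Any]]:
        # Only the training split is shuffled.
        return self._batches(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> Optional[Iterator[List[Any]]]:
        return self._maybe(self.val_dataset)

    def test_dataloader(self) -> Optional[Iterator[List[Any]]]:
        return self._maybe(self.test_dataset)

    def predict_dataloader(self) -> Optional[Iterator[List[Any]]]:
        return self._maybe(self.pred_dataset)


# Usage: ten placeholder "graphs", eight for training and two for validation.
graphs = [f"graph_{i}" for i in range(10)]
dm = SketchDataModule(graphs[:8], val_dataset=graphs[8:], batch_size=3)
train_batches = list(dm.train_dataloader())
print([len(b) for b in train_batches])  # [3, 3, 2]
print(list(dm.val_dataloader()))        # [['graph_8', 'graph_9']]
print(dm.test_dataloader())             # None
```

With the real class, each split would be a graph dataset and each batch a collated mini-batch of graphs rather than a Python list; the sketch only mirrors the split-to-loader bookkeeping.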

Note

Currently only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPStrategy training strategies of PyTorch Lightning are supported in order to correctly share data across all devices/processes:

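import pytorch_lightning as pl

# `model` (a LightningModule) and `datamodule` (a LightningDataset) are
# assumed to be defined beforehand; "ddp_spawn" spawns one process per GPU.
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)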
Parameters: