torch_geometric.data.lightning.LightningNodeData

class LightningNodeData(data: Union[Data, HeteroData], input_train_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_train_time: Optional[Tensor] = None, input_val_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_val_time: Optional[Tensor] = None, input_test_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_test_time: Optional[Tensor] = None, input_pred_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_pred_time: Optional[Tensor] = None, loader: str = 'neighbor', node_sampler: Optional[BaseSampler] = None, eval_loader_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any)[source]

基础类:LightningData

DataHeteroData对象转换为 pytorch_lightning.LightningDataModule变体。然后它可以 通过 PyTorch Lightning自动用作多GPU节点级训练的datamoduleLightningDataset将通过 NeighborLoader负责提供小批量数据。

注意

目前仅支持 pytorch_lightning.strategies.SingleDeviceStrategypytorch_lightning.strategies.DDPStrategy 训练 策略,以便在所有设备/进程之间正确共享数据:

import pytorch_lightning as pl
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)
Parameters:
  • data (DataHeteroData) – DataHeteroData 图对象。

  • input_train_nodes (torch.Tensorstr(str, torch.Tensor)) – 训练节点的索引。 如果未提供,将尝试通过搜索 data 对象中的 train_masktrain_idxtrain_index 属性来自动推断它们。 (默认值: None)

  • input_train_time (torch.Tensor, optional) – 训练节点的时间戳。(默认值:None

  • input_val_nodes (torch.Tensorstr(str, torch.Tensor)) – 验证节点的索引。 如果未提供,将尝试通过搜索 data 对象中的 val_maskvalid_maskval_idxvalid_idxval_indexvalid_index 属性来自动推断它们。 (默认值: None)

  • input_val_time (torch.Tensor, optional) – 验证边的时间戳。(默认值:None

  • input_test_nodes (torch.Tensorstr(str, torch.Tensor)) – 测试节点的索引。 如果未提供,将尝试通过搜索 data 对象中的 test_masktest_idxtest_index 属性来自动推断它们。 (默认值: None)

  • input_test_time (torch.Tensor, optional) – 测试节点的时间戳。(默认值:None

  • input_pred_nodes (torch.Tensorstr(str, torch.Tensor)) – 预测节点的索引。 如果未提供,将尝试通过搜索 data 对象中的 pred_maskpred_idxpred_index 属性来自动推断它们。 (默认值: None)

  • input_pred_time (torch.Tensor, optional) – 预测节点的时间戳。(默认值: None)

  • loader (str) – 使用的可扩展性技术 ("full", "neighbor")。(默认: "neighbor")

  • node_sampler (BaseSampler, 可选) – 一个自定义的采样器对象,用于生成小批量数据。如果设置了该选项,将忽略 loader 选项。(默认值:None

  • eval_loader_kwargs (Dict[str, Any], optional) – 自定义关键字参数 用于在评估期间覆盖 torch_geometric.loader.NeighborLoader 的配置。(默认值:None

  • **kwargs (可选) – torch_geometric.loader.NeighborLoader 的额外参数。