torch_geometric.data.lightning.LightningNodeData
- class LightningNodeData(data: Union[Data, HeteroData], input_train_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_train_time: Optional[Tensor] = None, input_val_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_val_time: Optional[Tensor] = None, input_test_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_test_time: Optional[Tensor] = None, input_pred_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_pred_time: Optional[Tensor] = None, loader: str = 'neighbor', node_sampler: Optional[BaseSampler] = None, eval_loader_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any)[source]
基础类:
LightningData将
Data或HeteroData对象转换为pytorch_lightning.LightningDataModule变体。然后它可以 通过 PyTorch Lightning自动用作多GPU节点级训练的datamodule。LightningDataset将通过NeighborLoader负责提供小批量数据。注意
目前仅支持
pytorch_lightning.strategies.SingleDeviceStrategy和pytorch_lightning.strategies.DDPStrategy训练 策略,以便在所有设备/进程之间正确共享数据:import pytorch_lightning as pl trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4) trainer.fit(model, datamodule)
- Parameters:
data (Data 或 HeteroData) –
Data或HeteroData图对象。input_train_nodes (torch.Tensor 或 str 或 (str, torch.Tensor)) – 训练节点的索引。 如果未提供,将尝试通过搜索
data对象中的train_mask、train_idx或train_index属性来自动推断它们。 (默认值:None)input_train_time (torch.Tensor, optional) – 训练节点的时间戳。(默认值:
None)input_val_nodes (torch.Tensor 或 str 或 (str, torch.Tensor)) – 验证节点的索引。 如果未提供,将尝试通过搜索
data对象中的val_mask、valid_mask、val_idx、valid_idx、val_index或valid_index属性来自动推断它们。 (默认值:None)input_val_time (torch.Tensor, optional) – 验证边的时间戳。(默认值:
None)input_test_nodes (torch.Tensor 或 str 或 (str, torch.Tensor)) – 测试节点的索引。 如果未提供,将尝试通过搜索
data对象中的test_mask、test_idx或test_index属性来自动推断它们。 (默认值:None)input_test_time (torch.Tensor, optional) – 测试节点的时间戳。(默认值:
None)input_pred_nodes (torch.Tensor 或 str 或 (str, torch.Tensor)) – 预测节点的索引。 如果未提供,将尝试通过搜索
data对象中的pred_mask、pred_idx或pred_index属性来自动推断它们。 (默认值:None)input_pred_time (torch.Tensor, optional) – 预测节点的时间戳。(默认值:
None)loader (str) – 使用的可扩展性技术 (
"full","neighbor")。(默认:"neighbor")node_sampler (BaseSampler, 可选) – 一个自定义的采样器对象,用于生成小批量数据。如果设置了该选项,将忽略
loader选项。(默认值:None)eval_loader_kwargs (Dict[str, Any], optional) – 自定义关键字参数 用于在评估期间覆盖
torch_geometric.loader.NeighborLoader的配置。(默认值:None)**kwargs (可选) –
torch_geometric.loader.NeighborLoader的额外参数。