torch_geometric.data.lightning.LightningLinkData

class LightningLinkData(data: Union[Data, HeteroData], input_train_edges: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, input_train_labels: Optional[Tensor] = None, input_train_time: Optional[Tensor] = None, input_val_edges: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, input_val_labels: Optional[Tensor] = None, input_val_time: Optional[Tensor] = None, input_test_edges: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, input_test_labels: Optional[Tensor] = None, input_test_time: Optional[Tensor] = None, input_pred_edges: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, input_pred_labels: Optional[Tensor] = None, input_pred_time: Optional[Tensor] = None, loader: str = 'neighbor', link_sampler: Optional[BaseSampler] = None, eval_loader_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any)

Bases: LightningData

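Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant. It can then be used automatically as a datamodule for multi-GPU link-level training via PyTorch Lightning. LightningLinkData takes care of providing mini-batches via LinkNeighborLoader.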
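To make the idea of a per-stage datamodule concrete, the following library-free toy mimics it: each stage (train, validation, test, predict) has its own seed edges, and the matching dataloader yields those edges in mini-batches. ToyLinkDataModule is a hypothetical stand-in whose method names follow the LightningDataModule hooks; it performs no neighbor sampling and is not the PyG or PyTorch Lightning implementation.

```python
from typing import Dict, Iterator, List, Optional, Tuple

Edge = Tuple[int, int]  # (source node, destination node)


class ToyLinkDataModule:
    """Hypothetical toy: maps each stage to its own list of seed edges."""

    def __init__(
        self,
        train_edges: List[Edge],
        val_edges: Optional[List[Edge]] = None,
        test_edges: Optional[List[Edge]] = None,
        pred_edges: Optional[List[Edge]] = None,
        batch_size: int = 2,
    ) -> None:
        self.splits: Dict[str, Optional[List[Edge]]] = {
            "train": train_edges,
            "val": val_edges,
            "test": test_edges,
            "pred": pred_edges,
        }
        self.batch_size = batch_size

    def _batches(self, stage: str) -> Iterator[List[Edge]]:
        edges = self.splits[stage]
        if edges is None:
            # Raised lazily, when the batches are first iterated.
            raise ValueError(f"no seed edges were given for stage '{stage}'")
        # Yield consecutive chunks of at most `batch_size` seed edges.
        for start in range(0, len(edges), self.batch_size):
            yield edges[start:start + self.batch_size]

    # Method names follow the LightningDataModule hooks.
    def train_dataloader(self) -> Iterator[List[Edge]]:
        return self._batches("train")

    def val_dataloader(self) -> Iterator[List[Edge]]:
        return self._batches("val")

    def test_dataloader(self) -> Iterator[List[Edge]]:
        return self._batches("test")

    def predict_dataloader(self) -> Iterator[List[Edge]]:
        return self._batches("pred")


dm = ToyLinkDataModule(train_edges=[(0, 1), (1, 2), (2, 3)], val_edges=[(3, 0)])
print(list(dm.train_dataloader()))  # [[(0, 1), (1, 2)], [(2, 3)]]
```

In the real class, each mini-batch additionally carries the sampled subgraph around its seed edges, which is what LinkNeighborLoader contributes.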

Note

Currently only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPStrategy training strategies of PyTorch Lightning are supported in order to correctly share data across all devices/processes:

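import pytorch_lightning as pl

# `model` is a pl.LightningModule and `datamodule` is a LightningLinkData
# instance; "ddp_spawn" launches one spawned process per GPU.
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)
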
Parameters:
  • data (Data or HeteroData or Tuple[FeatureStore, GraphStore]) – The Data or HeteroData graph object, or a tuple of FeatureStore and GraphStore objects.

  • input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]) – The training edges. (default: None)

  • input_train_labels (torch.Tensor, optional) – The labels of training edges. (default: None)

  • input_train_time (torch.Tensor, optional) – The timestamp of training edges. (default: None)

  • input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]) – The validation edges. (default: None)

  • input_val_labels (torch.Tensor, optional) – The labels of validation edges. (default: None)

  • input_val_time (torch.Tensor, optional) – The timestamp of validation edges. (default: None)

  • input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]) – The test edges. (default: None)

  • input_test_labels (torch.Tensor, optional) – The labels of test edges. (default: None)

  • input_test_time (torch.Tensor, optional) – The timestamp of test edges. (default: None)

  • input_pred_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]) – The prediction edges. (default: None)

  • input_pred_labels (torch.Tensor, optional) – The labels of prediction edges. (default: None)

  • input_pred_time (torch.Tensor, optional) – The timestamp of prediction edges. (default: None)

  • loader (str) – The scalability technique to use ("full", "neighbor"). (default: "neighbor")

  • link_sampler (BaseSampler, optional) – A custom sampler object to generate mini-batches. If set, the loader option is ignored. (default: None)

  • eval_loader_kwargs (Dict[str, Any], optional) – Custom keyword arguments that override the torch_geometric.loader.LinkNeighborLoader configuration during evaluation. (default: None)

  • **kwargs (optional) – Additional arguments of torch_geometric.loader.LinkNeighborLoader.
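Each input_*_edges argument accepts several forms. The library-free sketch below shows one way to resolve them into an (edge type, seed edges) pair: a raw edge index on a homogeneous graph, a bare EdgeType (all edges of that type), or an (EdgeType, Tensor) tuple (only the given edges). Plain nested lists stand in for [2, num_edges] tensors, and resolve_seed_edges is a hypothetical helper, not PyG's internal code. Treating None on a homogeneous graph as "all edges", and (EdgeType, None) as "all edges of that type", is an assumption about how PyG's link loaders behave.

```python
from typing import Dict, List, Optional, Tuple, Union

EdgeType = Tuple[str, str, str]  # (source node type, relation, destination node type)
EdgeIndex = List[List[int]]      # stand-in for a [2, num_edges] Tensor
EdgeInput = Union[None, EdgeIndex, EdgeType, Tuple[EdgeType, Optional[EdgeIndex]]]


def resolve_seed_edges(
    edge_store: Dict[Optional[EdgeType], EdgeIndex],
    edge_input: EdgeInput,
) -> Tuple[Optional[EdgeType], EdgeIndex]:
    """Hypothetical helper: turn an input_*_edges value into (edge type, edges).

    `edge_store` maps each edge type to its edge index; a homogeneous graph
    (the role played by Data) uses the single key None.
    """
    if None in edge_store:  # homogeneous graph
        # Assumption: None selects every edge of the graph.
        return None, (edge_store[None] if edge_input is None else edge_input)
    if edge_input is None:
        raise ValueError("heterogeneous graphs need an edge type")
    if isinstance(edge_input[0], str):  # a bare EdgeType
        return edge_input, edge_store[edge_input]
    edge_type, edge_index = edge_input  # an (EdgeType, Tensor or None) tuple
    return edge_type, (edge_store[edge_type] if edge_index is None else edge_index)


homo = {None: [[0, 1, 2], [1, 2, 0]]}
print(resolve_seed_edges(homo, None))  # (None, [[0, 1, 2], [1, 2, 0]])

rates = ("user", "rates", "movie")
hetero = {rates: [[0, 0, 1], [0, 1, 1]]}
print(resolve_seed_edges(hetero, (rates, [[1], [0]])))  # (('user', 'rates', 'movie'), [[1], [0]])
```

The matching input_*_labels and input_*_time tensors, when given, are expected to align one-to-one with the resolved seed edges.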
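The interplay between loader, link_sampler, eval_loader_kwargs and **kwargs can be summarized in a small sketch. resolve_loader_config is a hypothetical helper: it encodes the documented rule that a custom link_sampler makes the loader option irrelevant, and it assumes that eval_loader_kwargs is applied as a shallow dictionary merge over the training keyword arguments. The "custom" label is also hypothetical. The keys num_neighbors and batch_size are real LinkNeighborLoader arguments, used here only as plain dictionary keys.

```python
from typing import Any, Dict, Optional


def resolve_loader_config(
    loader: str = "neighbor",
    link_sampler: Optional[object] = None,
    eval_loader_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the documented option precedence."""
    if link_sampler is not None:
        # A custom sampler takes precedence: the loader option is ignored.
        loader = "custom"  # hypothetical label for "use link_sampler"
    elif loader not in ("full", "neighbor"):
        # This sketch only accepts the two documented values.
        raise ValueError(f"unsupported loader '{loader}'")
    # Evaluation reuses the training kwargs; keys in eval_loader_kwargs
    # replace the matching training keys (assumed shallow merge).
    eval_kwargs = {**kwargs, **(eval_loader_kwargs or {})}
    return {"loader": loader, "train_kwargs": dict(kwargs), "eval_kwargs": eval_kwargs}


cfg = resolve_loader_config(
    num_neighbors=[10, 5],
    batch_size=128,
    eval_loader_kwargs={"num_neighbors": [-1, -1]},  # -1: use all neighbors
)
print(cfg["eval_kwargs"])  # {'num_neighbors': [-1, -1], 'batch_size': 128}
```

Keeping evaluation overrides separate lets training use a cheap, sampled neighborhood while evaluation uses a larger or full one, without repeating the shared arguments.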