PyTorch Lightning 集成

PyTorch Lightning 集成。

PyTorch Lightning 提供了一种替代方法来实现知识图谱嵌入模型的训练循环和评估循环,具有一些很好的特性:

  • 混合精度训练

  • 多GPU训练

model = LitLCWAModule(
    dataset="fb15k237",
    dataset_kwargs=dict(create_inverse_triples=True),
    model="mure",
    model_kwargs=dict(embedding_dim=128, loss="bcewithlogits"),
    batch_size=128,
)
trainer = pytorch_lightning.Trainer(
    accelerator="auto",  # automatically choose accelerator
    logger=False,  # defaults to TensorBoard; explicitly disabled here
    precision=16,  # mixed precision training
)
trainer.fit(model=model)

LitModule([dataset, dataset_kwargs, mode, ...])

一个用于使用PyTorch Lightning训练模型的基础模块。

LCWALitModule([dataset, dataset_kwargs, ...])

一个用于使用LCWA训练循环训练模型的PyTorch Lightning模块。

SLCWALitModule(*[, negative_sampler, ...])

一个用于使用sLCWA训练循环训练模型的PyTorch Lightning模块。