第一步
训练和评估模型的最简单方法是使用pykeen.pipeline.pipeline()函数。
它提供了一个高级入口点,用于访问此包的可扩展功能。有关管道和相关函数的完整参考文档可以在pykeen.pipeline找到。
训练模型
以下示例展示了如何在pykeen.datasets.Nations数据集上训练和评估pykeen.models.TransE模型。在文档中,您会注意到每个资产在PyKEEN中都有相应的类。您可以点击链接了解更多信息,并查看如何使用它们的参考。不用担心,在本教程的这一部分中,pykeen.pipeline.pipeline()函数将为您处理一切。
>>> from pykeen.pipeline import pipeline
>>> pipeline_result = pipeline(
... dataset='Nations',
... model='TransE',
... )
>>> pipeline_result.save_to_directory('nations_transe')
结果返回在一个pykeen.pipeline.PipelineResult实例中,该实例具有训练模型、训练循环和评估的属性。
在这个例子中,模型是以字符串形式给出的。可用的模型列表可以在
pykeen.models 中找到。或者,可以使用与模型实现相对应的类,如下所示:
>>> from pykeen.pipeline import pipeline
>>> from pykeen.models import TransE
>>> pipeline_result = pipeline(
... dataset='Nations',
... model=TransE,
... )
>>> pipeline_result.save_to_directory('nations_transe')
在这个例子中,数据集是以字符串形式给出的。可用的数据集列表可以在
pykeen.datasets 中找到。或者,可以使用 pykeen.datasets.Dataset 的子类,如下所示:
>>> from pykeen.pipeline import pipeline
>>> from pykeen.models import TransE
>>> from pykeen.datasets import Nations
>>> pipeline_result = pipeline(
... dataset=Nations,
... model=TransE,
... )
>>> pipeline_result.save_to_directory('nations_transe')
在前三个例子中,训练方法、优化器和评估方案被省略了。默认情况下,模型是在随机局部封闭世界假设(sLCWA;pykeen.training.SLCWATrainingLoop)下训练的。这可以明确地作为一个字符串给出:
>>> from pykeen.pipeline import pipeline
>>> pipeline_result = pipeline(
... dataset='Nations',
... model='TransE',
... training_loop='sLCWA',
... )
>>> pipeline_result.save_to_directory('nations_transe')
或者,模型可以在局部封闭世界假设(LCWA;
pykeen.training.LCWATrainingLoop)下通过提供'LCWA'进行训练。
不需要额外的配置,但值得阅读这些训练方法之间的差异。可用的训练假设列表可以在pykeen.training中找到。
>>> from pykeen.pipeline import pipeline
>>> pipeline_result = pipeline(
... dataset='Nations',
... model='TransE',
... training_loop='LCWA',
... )
>>> pipeline_result.save_to_directory('nations_transe')
其中一个区别是sLCWA依赖于负采样。负采样的类型可以如下给出:
>>> from pykeen.pipeline import pipeline
>>> pipeline_result = pipeline(
... dataset='Nations',
... model='TransE',
... training_loop='sLCWA',
... negative_sampler='basic',
... )
>>> pipeline_result.save_to_directory('nations_transe')
在这个例子中,负采样器被指定为一个字符串。可用的负采样器列表可以在pykeen.sampling中找到。或者,可以使用负采样器实现对应的类,如下所示:
>>> from pykeen.pipeline import pipeline
>>> from pykeen.sampling import BasicNegativeSampler
>>> pipeline_result = pipeline(
... dataset='Nations',
... model='TransE',
... training_loop='sLCWA',
... negative_sampler=BasicNegativeSampler,
... )
>>> pipeline_result.save_to_directory('nations_transe')
警告
如果正在使用LCWA,则不应使用negative_sampler关键字参数。
一般来说,所有其他选项在任一训练方法下都可用。
评估的类型可以通过evaluator关键字指定。默认情况下,使用基于排名的评估。可以明确给出如下:
>>> from pykeen.pipeline import pipeline
>>> pipeline_result = pipeline(
... dataset='Nations',
... model='TransE',
... evaluator='RankBasedEvaluator',
... )
>>> pipeline_result.save_to_directory('nations_transe')
在这个例子中,评估器字符串。可用的评估器列表可以在
pykeen.evaluation 中找到。或者,可以使用与评估器实现对应的类,如下所示:
>>> from pykeen.pipeline import pipeline
>>> from pykeen.evaluation import RankBasedEvaluator
>>> pipeline_result = pipeline(
... dataset='Nations',
... model='TransE',
... evaluator=RankBasedEvaluator,
... )
>>> pipeline_result.save_to_directory('nations_transe')
PyKEEN 实现了早停功能,可以通过 stopper 关键字参数来开启,如下所示:
>>> from pykeen.pipeline import pipeline
>>> pipeline_result = pipeline(
... dataset='Nations',
... model='TransE',
... stopper='early',
... )
>>> pipeline_result.save_to_directory('nations_transe')
在 PyKEEN 中,您还可以使用 PyTorch 提供的学习率调度器,可以通过 lr_scheduler 关键字参数以及 lr_scheduler_kwargs 关键字参数来启用,并为学习率调度器指定参数,如下所示:
>>> from pykeen.pipeline import pipeline
>>> pipeline_result = pipeline(
... dataset='Nations',
... model='TransE',
... lr_scheduler='ExponentialLR',
... lr_scheduler_kwargs=dict(
... gamma=0.99,
... ),
... )
>>> pipeline_result.save_to_directory('nations_transe')
更深入的配置
模型的参数可以作为字典使用model_kwargs给出。
>>> from pykeen.pipeline import pipeline
>>> pipeline_result = pipeline(
... dataset='Nations',
... model='TransE',
... model_kwargs=dict(
... scoring_fct_norm=2,
... ),
... )
>>> pipeline_result.save_to_directory('nations_transe')
model_kwargs 中的条目对应于传递给 pykeen.models.TransE.__init__() 的参数。有关模型的完整列表,请参见 pykeen.models,其中有每个模型的参考链接,解释了可能的 kwargs。除非在模型的参考页面上另有说明,否则每个模型的默认超参数是根据最初发布模型的论文中报告的最佳值选择的。
因为管道负责查找类并实例化它们,所以pykeen.pipeline.pipeline()还有其他几个参数,可以在各自的实例化过程中用于指定参数。
可以使用dataset_kwargs向数据集提供参数。这些参数会被传递给pykeen.datasets.Nations
加载预训练模型
许多之前的示例都以使用pykeen.pipeline.PipelineResult.save_to_directory()保存结果结束。写入给定目录的其中一个文件是trained_model.pkl。因为所有PyKEEN模型都继承自torch.nn.Module,我们使用PyTorch的机制来保存和加载它们。这意味着你可以使用torch.load()来加载一个模型,如下所示:
import torch
my_pykeen_model = torch.load("trained_model.pkl")
更多关于PyTorch模型持久化的信息可以在以下链接找到: https://pytorch.org/tutorials/beginner/saving_loading_models.html.
将实体和关系标识符映射到它们的名称
虽然PyKEEN在内部将实体和关系映射到连续的标识符,但能够使用实体和关系的标签与数据集、三元组工厂和模型进行交互仍然非常有用。
我们可以使用TriplesFactory.entities_to_ids()将三元组工厂的实体映射到标识符,如下例所示:
import torch
from pykeen.datasets.utils import get_dataset
from pykeen.triples.triples_factory import TriplesFactory
# As an example, we will use a small dataset that comes with entity and relation labels.
dataset = get_dataset(dataset="nations")
triples_factory = dataset.training
# Get tensor of entity identifiers
entity_ids = torch.as_tensor(triples_factory.entities_to_ids(["china", "egypt"]))
同样地,我们可以使用TriplesFactory.relations_to_ids将三元组工厂的关系映射到标识符,如下例所示:
relation_ids = torch.as_tensor(triples_factory.relations_to_ids(["independence", "embassy"]))
警告
重要的是要注意,我们应该使用与训练模型时相同的映射的三元组工厂 - 否则我们可能会得到不正确的ID。
使用学习到的嵌入
为实体和关系学习的嵌入不仅对链接预测有用(参见预测),还对其他下游机器学习任务如聚类、回归和分类有用。
知识图谱嵌入模型可能具有多个实体表示和多个关系表示,因此它们分别存储在每个模型的entity_representations和relation_representations属性中作为序列。虽然这些序列的具体内容取决于模型,但每个序列的第一个元素通常是实体或关系的“主要”表示。
通常,这些序列中的值是pykeen.nn.representation.Embedding的实例。
这实现了一个与内置torch.nn.Embedding类似但更强大的接口。
然而,这些序列中的值更一般地可以是pykeen.nn.representation.Representation的任何子类的实例。
这使得在GNN中实现和使用更强大的编码器成为可能,例如pykeen.models.RGCN。
实体表示和关系表示可以这样访问:
from pykeen.models import ERModel
from pykeen.pipeline import pipeline
# train a model
result = pipeline(model="TransE", dataset="nations")
model = result.model
assert isinstance(model, ERModel)
# access entity and relation representations
entity_representation_modules = model.entity_representations
relation_representation_modules = model.relation_representations
大多数模型,如pykeen.models.TransE,只有一个实体表示和一个关系表示。这意味着entity_representations和relation_representations列表的长度都为1。所有实体嵌入可以像这样访问:
from pykeen.nn.representation import Embedding # noqa: E402
# TransE has one representation for entities and one for relations
# both are simple embedding matrices
entity_embeddings = entity_representation_modules[0]
relation_embeddings = relation_representation_modules[0]
assert isinstance(entity_embeddings, Embedding)
assert isinstance(relation_embeddings, Embedding)
由于所有表示都是torch.nn.Module的子类,你需要像调用函数一样调用它们来触发forward()并获取值。
entity_embedding_tensor = entity_embeddings()
relation_embedding_tensor = relation_embeddings()
所有pykeen.nn.representation.Representation的forward()函数都接受一个indices参数。
默认情况下,它是None并返回所有值。更明确地说,这看起来像:
entity_embedding_tensor = entity_embeddings(indices=None)
relation_embedding_tensor = relation_embeddings(indices=None)
如果您只想查找某些嵌入,可以使用indices参数
并传递一个torch.LongTensor及其相应的索引。
import torch # noqa: E402
entity_embedding_tensor = entity_embeddings(indices=torch.as_tensor([1, 3]))
你可能想要将它们从GPU中分离并转换为numpy.ndarray
entity_embedding_tensor = entity_embeddings.detach().cpu().numpy()
警告
一些旧式模型(例如,继承自pykeen.models.EntityRelationEmbeddingModel的模型)
没有完全实现entity_representations和relation_representations接口。这意味着
它们可能在这些序列中未暴露的属性中存储了额外的嵌入。
例如,pykeen.models.TransD在
pykeen.models.TransD.entity_projections中有一个次要的实体嵌入。
最终,所有模型都将升级为新式模型,这将不再是一个问题。
超越管道
虽然管道提供了一个高级接口,但训练过程的每个方面都被封装在可以更精细调整或子类化的类中。以下是可能已与之前示例之一执行的代码示例。
# Get a training dataset
from pykeen.datasets import get_dataset
dataset = get_dataset(dataset="nations")
training = dataset.training
validation = dataset.validation
testing = dataset.testing
# The following applies to most packaged datasets,
# although the dataset class itself makes `validation' optional.
assert validation is not None
# Pick a model
from pykeen.models import TransE
model = TransE(triples_factory=training)
# Pick an optimizer from PyTorch
from torch.optim import Adam
optimizer = Adam(params=model.get_grad_params())
# Pick a training approach (sLCWA or LCWA)
from pykeen.training import SLCWATrainingLoop
training_loop = SLCWATrainingLoop(
model=model,
triples_factory=training,
optimizer=optimizer,
)
# Train like Cristiano Ronaldo
_ = training_loop.train(
triples_factory=training,
num_epochs=5,
batch_size=256,
)
# Pick an evaluator
from pykeen.evaluation import RankBasedEvaluator
evaluator = RankBasedEvaluator()
# Evaluate
results = evaluator.evaluate(
model=model,
mapped_triples=testing.mapped_triples,
batch_size=1024,
additional_filter_triples=[
training.mapped_triples,
validation.mapped_triples,
],
)
# print(results)
预览:评估循环
PyKEEN 目前正在过渡到使用 torch 的数据加载器进行评估。 虽然对于高层的 pipeline 尚未激活,但您已经可以显式地使用它:
from pykeen.datasets import Nations
from pykeen.evaluation import LCWAEvaluationLoop
from pykeen.models import TransE
from pykeen.training import SLCWATrainingLoop
# get a dataset
dataset = Nations()
# Pick a model
model = TransE(triples_factory=dataset.training)
# Pick a training approach (sLCWA or LCWA)
training_loop = SLCWATrainingLoop(
model=model,
triples_factory=dataset.training,
)
# Train like Cristiano Ronaldo
_ = training_loop.train(
triples_factory=dataset.training,
num_epochs=5,
batch_size=256,
# NEW: validation evaluation callback
callbacks="evaluation-loop",
callbacks_kwargs=dict(
prefix="validation",
factory=dataset.validation,
),
)
# Pick an evaluation loop (NEW)
evaluation_loop = LCWAEvaluationLoop(
model=model,
triples_factory=dataset.testing,
)
# Evaluate
results = evaluation_loop.evaluate()
# print(results)
训练回调
PyKEEN 允许通过回调与训练循环进行交互。 一个特定的用例是定期评估(不包括早停器)。 以下示例展示了如何在每十个周期对训练三元组进行评估
from pykeen.datasets import get_dataset
from pykeen.pipeline import pipeline
dataset = get_dataset(dataset="nations")
result = pipeline(
dataset=dataset,
model="mure",
training_kwargs=dict(
num_epochs=100,
callbacks="evaluation",
callbacks_kwargs=dict(
evaluation_triples=dataset.training.mapped_triples,
tracker="console",
prefix="training",
),
),
)
有关不同结果跟踪器的更多信息,请查看结果跟踪器部分。
下一步
入门教程教你如何为一些最常见的任务训练和使用模型。在文档的这一部分中,还有其他几个特定主题的教程。如果你遇到问题,你可能还想跳到故障排除部分,或者查看其他人在GitHub上发布的问题和讨论。