分区表示

class PartitionRepresentation(assignment: Tensor, shape: int | Sequence[int] | None = None, bases: str | Representation | type[Representation] | None | Sequence[str | Representation | type[Representation] | None] = None, bases_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, **kwargs)[source]

基础类:Representation

将索引划分为不同的表示模块。

每个索引都被分配到基础表示中的一个索引。这种表示方式在例如当其中一个基础表示无法为每个索引提供向量时,另一个表示被用作备用时非常有用。

考虑以下示例:我们只有两个实体的文本信息。我们希望使用从它们计算出的文本特征,这些特征不应该被训练。对于其余的实体,我们希望直接使用可训练的嵌入。

我们首先为那些有标签的实体创建表示:

>>> from pykeen.nn import Embedding, init
>>> num_entities = 5
>>> labels = {1: "a first description", 4: "a second description"}
>>> label_initializer = init.LabelBasedInitializer(labels=list(labels.values()))
>>> label_repr = label_initializer.as_embedding()

接下来,我们为剩余的部分创建表示

>>> non_label_repr = Embedding(max_id=num_entities - len(labels), shape=label_repr.shape)

要将它们组合成一个单一的表示模块,我们首先需要定义分配,即在哪里查找全局ID。为此,我们创建一个形状为(num_entities, 2)的张量,其中包含基础表示的索引以及此表示中的本地索引。

>>> import torch
>>> assignment = torch.as_tensor([(1, 0), (0, 0), (1, 1), (1, 2), (0, 1)])
>>> from pykeen.nn import PartitionRepresentation
>>> entity_repr = PartitionRepresentation(assignment=assignment, bases=[label_repr, non_label_repr])

为了简洁起见,我们在这里使用随机生成的三元组工厂而不是实际数据

>>> from pykeen.triples.generation import generate_triples_factory
>>> training = generate_triples_factory(num_entities=num_entities, num_relations=5, num_triples=31)
>>> testing = generate_triples_factory(num_entities=num_entities, num_relations=5, num_triples=17)

组合表示现在可以像其他表示一样使用,例如,用于训练DistMult模型:

>>> from pykeen.pipeline import pipeline
>>> from pykeen.models import ERModel
>>> pipeline(
...     model=ERModel,
...     interaction="distmult",
...     model_kwargs=dict(
...         entity_representation=entity_repr,
...         relation_representation_kwargs=dict(shape=shape),
...     ),
...     training=training,
...     testing=testing,
... )

初始化表示。

警告

基础表示必须具有一致的形状

Parameters:
  • assignment (Tensor) – 形状: (max_id, 2) 分配,作为元组 (base_id, local_id),其中 base_id 指的是基础表示的索引,而 local_id 是用于在基础表示中查找的索引

  • shape (tuple[int, ...]) – 单个表示的形状。如果提供,必须与基础的形状匹配

  • bases (OneOrSequence[HintOrType[Representation]]) – 基础表示或其提示。

  • bases_kwargs (OneOrSequence[OptionalKwargs]) – 用于实例化基础表示的关键字参数

  • kwargs – 传递给 Representation.__init__() 的额外基于关键字的参数。不能包含 max_idshape,这些是从基础表示中推断出来的。

Raises:

ValueError – 如果任何输入无效