表示法

class Representation(max_id: int, shape: int | Sequence[int] = 64, normalizer: str | Callable[[Tensor], Tensor] | type[Callable[[Tensor], Tensor]] | None = None, normalizer_kwargs: Mapping[str, Any] | None = None, regularizer: str | Regularizer | type[Regularizer] | None = None, regularizer_kwargs: Mapping[str, Any] | None = None, dropout: float | None = None, unique: bool | None = None)[源代码]

基础类: Module, ExtraReprMixin, ABC

用于获取实体/关系表示的基础类。

表示模块将整数ID映射到表示,这些表示是浮点数的张量。

max_id 定义了我们允许请求的索引的上限(不包括该上限)。对于简单的嵌入,这等同于 num_embeddings,但对于一般的非嵌入表示来说,这是一个更合适的词,因为这些表示可能来自其他地方,例如一个 GNN 编码器。

shape 描述了一个单一表示的形状。在向量嵌入的情况下,这只是一个单一的维度。对于其他情况,例如 pykeen.models.RESCAL,我们有二维表示,通常它可以是任何固定的形状。

我们可以将所有表示视为形状为(max_id, *shape)的张量,这正是将indices=None传递给前向方法的结果。

我们还可以将多维的索引传递给forward方法,在这种情况下,索引的形状将成为结果形状的前缀:(*indices.shape, *self.shape)

初始化表示模块。

Parameters:
  • max_id (int) – 最大ID(不包括)。有效的ID范围从0到max_id-1

  • shape (tuple[int, ...]) – 单个表示的形状。

  • normalizer (Callable[[Tensor], Tensor] | None) – 一个归一化函数,它在每次前向传递中应用于选定的表示。

  • normalizer_kwargs (OptionalKwargs) – 传递给normalizer的额外关键字参数

  • regularizer (Regularizer | None) – 一个输出正则化器,在前向传递中应用于选定的表示

  • regularizer_kwargs (OptionalKwargs) – 传递给正则化器的额外关键字参数

  • dropout (Dropout | None) – 可选的dropout概率

  • unique (bool | None) – 是否优化为仅对相同索引计算一次表示。这仅在表示的计算比基于索引的查找显著更昂贵且预期有重复索引时有用,例如,在使用负采样和大批量时。

属性摘要

device

返回设备。

方法总结

forward([indices])

获取索引的表示。

iter_extra_repr()

遍历组件以用于 extra_repr()

post_parameter_update()

应用不应包含在梯度中的约束。

reset_parameters()

重置模块的参数。

属性文档

device

返回设备。

方法文档

forward(indices: Tensor | None = None) Tensor[source]

获取索引的表示。

注意

根据 Representation.unique,此实现将使用针对重复索引的优化。通常仅在计算单个表示成本较高时推荐使用,例如,因为它涉及消息传递或大型编码器网络,但对于成本较低的查找(例如,普通的嵌入查找)则不推荐使用。

Parameters:

indices (Tensor | None) – 形状: s 索引,或None。如果为None,则解释为 torch.arange(self.max_id)(尽管实现更高效)。

Returns:

形状: (*s, *self.shape) 表示形式。

Return type:

Tensor

iter_extra_repr() Iterable[str][source]

遍历组件以用于 extra_repr()

Return type:

Iterable[str]

post_parameter_update()[source]

应用不应包含在梯度中的约束。

reset_parameters() None[source]

重置模块的参数。

Return type: