RESCAL交互

class RESCALInteraction(*args, **kwargs)[source]

基础类: Interaction[Tensor, Tensor, Tensor]

无状态的RESCAL交互函数。

对于头和尾实体表示 \(\mathbf{h}, \mathbf{t} \in \mathbb{R}^d\) 和关系表示 \(\mathbf{R} \in \mathbb{R}^{d \times d}\),交互函数给出为

\[\mathbf{h}^T \textbf{R} \textbf{t} = \sum_{i=1}^{d} \sum_{j=1}^{d} \mathbf{h}_i \mathbf{R}_{i, j} \mathbf{t}_{i}\]

因此,关系矩阵 \(\textbf{R}\) 包含权重 \(\textbf{R}_{i, j}\),这些权重捕捉了头部表示的第 \(i\) 个潜在因子与第 \(j\) 个潜在因子之间的交互量。

计算复杂度由\(\mathcal{O}(d^2)\)给出。

初始化内部模块状态,由nn.Module和ScriptModule共享。

属性摘要

relation_shape

关系表示的符号形状

方法总结

forward(h, r, t)

评估交互函数。

属性文档

relation_shape: Sequence[str] = ('dd',)

关系表示的符号形状

方法文档

forward(h: Tensor, r: Tensor, t: Tensor) Tensor[来源]

评估交互函数。

另请参阅

Interaction.forward 提供了关于交互函数通用批处理形式的详细描述。

Parameters:
  • h (Tensor) – 形状: (*batch_dims, d) 头部表示。

  • r (Tensor) – 形状: (*batch_dims, d) 关系表示。

  • t (Tensor) – 形状: (*batch_dims, d) 尾部表示。

Returns:

形状: batch_dims 分数。

Return type:

Tensor