NTN交互

class NTNInteraction(activation: str | Module | type[Module] | None = None, activation_kwargs: Mapping[str, Any] | None = None)[source]

基础类:Interaction[Tensor, tuple[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]

无状态的神经张量网络(NTN)交互函数。

它由以下给出

\[\mathbf{r}_{u}^{T} \cdot \sigma( \mathbf{h} \mathbf{R}_{3} \mathbf{t} + \mathbf{R}_{2} [\mathbf{h};\mathbf{t}] + \mathbf{r}_1 )\]

使用 \(\mathbf{W}_3 \in \mathbb{R}^{d \times d \times k}\), \(\textbf{R}_2 \in \mathbb{R}^{k \times 2d}\), 偏置向量 \(\textbf{r}_1\), 最终投影 \(\textbf{r}_u \in \mathbb{R}^k\), 和一个非线性激活 函数 \(\sigma\) (默认为 Tanh).

它可以被视为具有关系特定权重的两层MLP的扩展,并在输入层中增加了一个双线性张量。 为每个关系单独参数化的神经网络使模型非常具有表现力,但也计算成本高昂(\(\mathcal{O}(kd^2)\))。

注意

我们将原始的\(k \times 2d\)维的\(\mathbf{R}_2\)矩阵分成两个形状为\(k \times d\)的部分,以支持更高效的1:n评分,例如在score_h()score_t()设置中。

使用给定的非线性激活函数初始化NTN。

Parameters:
  • activation (HintOrType[nn.Module]) – 一个非线性激活函数。默认为双曲正切函数 torch.nn.Tanh 如果 None

  • activation_kwargs (Mapping[str, Any] | None) – 如果 activation 作为类传递,这些关键字参数在其实例化期间使用。

注意

参数对 (activation, activation_kwargs) 用于 class_resolver.contrib.torch.activation_resolver

解析器的解释及其使用方法在 https://class-resolver.readthedocs.io/en/latest/中给出。

属性摘要

relation_shape

关系表示的符号形状

方法总结

forward(h, r, t)

评估交互函数。

属性文档

relation_shape: Sequence[str] = ('kdd', 'kd', 'kd', 'k', 'k')

关系表示的符号形状

方法文档

forward(h: Tensor, r: tuple[Tensor, Tensor, Tensor, Tensor, Tensor], t: Tensor) Tensor[来源]

评估交互函数。

另请参阅

Interaction.forward 提供了关于交互函数通用批处理形式的详细描述。

Parameters:
  • h (Tensor) – 形状: (*batch_dims, d) 头部表示。

  • r (tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) – 形状: (*batch_dims, k, d, d), (*batch_dims, k, d), (*batch_dims, k, d), (*batch_dims, k), 和 (*batch_dims, k) 关系表示。

  • t (Tensor) – 形状: (*batch_dims, d) 尾部表示。

Returns:

形状: batch_dims 分数。

Return type:

Tensor