torch_geometric.nn.kge.TransE

class TransE(num_nodes: int, num_relations: int, hidden_channels: int, margin: float = 1.0, p_norm: float = 1.0, sparse: bool = False)[source]

基础类:KGEModel

来自“Translating Embeddings for Modeling Multi-Relational Data”论文的TransE模型。

TransE 将关系建模为从头实体到尾实体的翻译,使得

\[\mathbf{e}_h + \mathbf{e}_r \approx \mathbf{e}_t,\]

导致评分函数:

\[d(h, r, t) = - {\| \mathbf{e}_h + \mathbf{e}_r - \mathbf{e}_t \|}_p\]

注意

有关使用 TransE 模型的示例,请参见 examples/kge_fb15k_237.py

Parameters:
  • num_nodes (int) – 图中节点/实体的数量。

  • num_relations (int) – 图中关系的数量。

  • hidden_channels (int) – 隐藏嵌入大小。

  • margin (int, optional) – 排序损失的边距。 (默认值: 1.0)

  • p_norm (int, optional) – 嵌入和距离归一化的顺序。 (默认: 1.0)

  • sparse (bool, 可选) – 如果设置为 True,关于嵌入矩阵的梯度将是稀疏的。(默认值:False

reset_parameters()[source]

重置模块的所有可学习参数。

forward(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) Tensor[source]

返回给定三元组的分数。

Parameters:
Return type:

Tensor

loss(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) Tensor[source]

返回给定三元组的损失值。

Parameters:
Return type:

Tensor