torch_geometric.nn.kge.ComplEx

class ComplEx(num_nodes: int, num_relations: int, hidden_channels: int, sparse: bool = False)[source]

Bases: KGEModel

来自“Complex Embeddings for Simple Link Prediction”论文的ComplEx模型。

ComplEx 将关系建模为使用 Hermetian 点积的头实体和尾实体之间的复值双线性映射。 实体和关系嵌入在不同维度的空间中,从而得到评分函数:

\[d(h, r, t) = Re(< \mathbf{e}_h, \mathbf{e}_r, \mathbf{e}_t>)\]

注意

有关使用 ComplEx 模型的示例,请参见 examples/kge_fb15k_237.py

Parameters:
  • num_nodes (int) – The number of nodes/entities in the graph.

  • num_relations (int) – The number of relations in the graph.

  • hidden_channels (int) – The hidden embedding size.

  • sparse (bool, optional) – If set to True, gradients w.r.t. to the embedding matrices will be sparse. (default: False)

reset_parameters()[source]

重置模块的所有可学习参数。

forward(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) Tensor[source]

返回给定三元组的分数。

Parameters:
Return type:

Tensor

loss(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) Tensor[source]

返回给定三元组的损失值。

Parameters:
Return type:

Tensor