torch_geometric.nn.kge.TransE
- class TransE(num_nodes: int, num_relations: int, hidden_channels: int, margin: float = 1.0, p_norm: float = 1.0, sparse: bool = False)[source]
基础类:
KGEModel来自“Translating Embeddings for Modeling Multi-Relational Data”论文的TransE模型。
TransE将关系建模为从头实体到尾实体的翻译,使得\[\mathbf{e}_h + \mathbf{e}_r \approx \mathbf{e}_t,\]导致评分函数:
\[d(h, r, t) = - {\| \mathbf{e}_h + \mathbf{e}_r - \mathbf{e}_t \|}_p\]注意
有关使用
TransE模型的示例,请参见 examples/kge_fb15k_237.py。- Parameters:
- forward(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) Tensor[source]
返回给定三元组的分数。
- Parameters:
head_index (torch.Tensor) – 头部索引。
rel_type (torch.Tensor) – 关系类型。
tail_index (torch.Tensor) – 尾部索引。
- Return type:
- loss(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) Tensor[source]
返回给定三元组的损失值。
- Parameters:
head_index (torch.Tensor) – 头部索引。
rel_type (torch.Tensor) – 关系类型。
tail_index (torch.Tensor) – 尾部索引。
- Return type: