torch_geometric.nn.kge.DistMult
- class DistMult(num_nodes: int, num_relations: int, hidden_channels: int, margin: float = 1.0, sparse: bool = False)[source]
Bases:
KGEModel来自“知识库中嵌入实体和关系以进行学习和推理”论文的DistMult模型。
DistMult将关系建模为对角矩阵,这简化了头实体和尾实体之间的双线性交互,得分函数为:\[d(h, r, t) = < \mathbf{e}_h, \mathbf{e}_r, \mathbf{e}_t >\]注意
有关使用
DistMult模型的示例,请参见 examples/kge_fb15k_237.py。- Parameters:
num_nodes (int) – The number of nodes/entities in the graph.
num_relations (int) – The number of relations in the graph.
hidden_channels (int) – The hidden embedding size.
margin (float, optional) – 排名损失的边距。 (默认值:
1.0)sparse (bool, optional) – If set to
True, gradients w.r.t. to the embedding matrices will be sparse. (default:False)
- forward(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) Tensor[source]
返回给定三元组的分数。
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
- Return type:
- loss(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) Tensor[source]
返回给定三元组的损失值。
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
- Return type: