torch_geometric.nn.kge.KGEModel
- class KGEModel(num_nodes: int, num_relations: int, hidden_channels: int, sparse: bool = False)
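Bases: Module
An abstract base class for implementing custom KGE models.
- Parameters:
num_nodes (int) – The number of nodes/entities in the graph.
num_relations (int) – The number of relations in the graph.
hidden_channels (int) – The hidden embedding size.
sparse (bool, optional) – If set to True, gradients w.r.t. the embedding matrices will be sparse. (default: False)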
- forward(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) → Tensor
Returns the score for the given triplet.
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
- Return type:
Tensor
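KGEModel leaves the scoring function to its subclasses. As a framework-free illustration of what a subclass's forward might compute, the sketch below scores each triplet with a TransE-style distance, -||h + r - t||_1, using plain Python lists in place of embedding tables. The class name ToyTransE, the random initialization, and the choice of the L1 norm are assumptions made for illustration; this is not the PyG implementation.

```python
import random
from typing import List


class ToyTransE:
    """Hypothetical stand-in for a KGEModel subclass (not PyG code)."""

    def __init__(self, num_nodes: int, num_relations: int,
                 hidden_channels: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        # One embedding vector per entity and per relation, mirroring the
        # num_nodes / num_relations / hidden_channels constructor arguments.
        self.node_emb = [[rng.uniform(-1, 1) for _ in range(hidden_channels)]
                         for _ in range(num_nodes)]
        self.rel_emb = [[rng.uniform(-1, 1) for _ in range(hidden_channels)]
                        for _ in range(num_relations)]

    def forward(self, head_index: List[int], rel_type: List[int],
                tail_index: List[int]) -> List[float]:
        # Score every (head, relation, tail) triplet; higher means more plausible.
        scores = []
        for h, r, t in zip(head_index, rel_type, tail_index):
            head, rel, tail = self.node_emb[h], self.rel_emb[r], self.node_emb[t]
            # L1 distance between the translated head (h + r) and the tail.
            dist = sum(abs(a + b - c) for a, b, c in zip(head, rel, tail))
            scores.append(-dist)
        return scores


model = ToyTransE(num_nodes=4, num_relations=2, hidden_channels=3)
print(model.forward([0, 1], [0, 1], [2, 3]))  # one score per triplet
```

The three index arguments are aligned element-wise, so position i of head_index, rel_type and tail_index together describe the i-th triplet, and forward returns one score per position.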
- loss(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) → Tensor
Returns the loss value for the given triplet.
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
- Return type:
Tensor
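The base class does not prescribe a particular loss. One common choice for distance-based models, shown here purely as an assumption, is a margin ranking loss that pushes each positive triplet's score above the score of a corrupted counterpart by at least a margin. The helper name margin_ranking_loss and the default margin=1.0 are illustrative.

```python
from typing import List


def margin_ranking_loss(pos_scores: List[float], neg_scores: List[float],
                        margin: float = 1.0) -> float:
    """Mean of max(0, margin - pos + neg) over paired scores (hypothetical helper)."""
    if len(pos_scores) != len(neg_scores):
        raise ValueError("expected one negative score per positive score")
    # Each pair contributes only if the positive fails to beat the negative by `margin`.
    terms = [max(0.0, margin - p + n) for p, n in zip(pos_scores, neg_scores)]
    return sum(terms) / len(terms)


# Positives already beat negatives by at least the margin: zero loss.
print(margin_ranking_loss([3.0, 2.5], [1.0, 0.5]))  # 0.0
# A negative that outscores its positive is penalized.
print(margin_ranking_loss([0.0], [0.5]))  # 1.5
```

In PyTorch the same quantity is available as torch.nn.functional.margin_ranking_loss with a target of 1 for every pair.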
- loader(head_index: Tensor, rel_type: Tensor, tail_index: Tensor, **kwargs) → Tensor
Returns a mini-batch loader that samples a subset of triplets.
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
**kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size, shuffle, drop_last or num_workers.
- Return type:
Tensor
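Conceptually, such a loader walks over the positions 0..N-1 of the N training triplets and yields aligned (head, relation, tail) mini-batches. The generator below is a framework-free sketch of that behaviour; the name triplet_batches and its batch_size, shuffle, drop_last and seed arguments are hypothetical and only loosely mirror the DataLoader keyword arguments listed above.

```python
import random
from typing import Iterator, List, Tuple

Batch = Tuple[List[int], List[int], List[int]]


def triplet_batches(head_index: List[int], rel_type: List[int],
                    tail_index: List[int], batch_size: int,
                    shuffle: bool = False, drop_last: bool = False,
                    seed: int = 0) -> Iterator[Batch]:
    """Yield aligned (head, rel, tail) mini-batches (hypothetical helper)."""
    n = len(head_index)
    order = list(range(n))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        if drop_last and len(idx) < batch_size:
            break  # skip the incomplete final batch
        # The same positions are taken from all three lists, so triplets stay intact.
        yield ([head_index[i] for i in idx],
               [rel_type[i] for i in idx],
               [tail_index[i] for i in idx])


heads, rels, tails = [0, 1, 2, 3, 4], [0, 0, 1, 1, 0], [1, 2, 3, 4, 0]
for batch in triplet_batches(heads, rels, tails, batch_size=2):
    print(batch)
# ([0, 1], [0, 0], [1, 2])
# ([2, 3], [1, 1], [3, 4])
# ([4], [0], [0])
```

With the real API, the DataLoader options are simply passed through, e.g. model.loader(head_index, rel_type, tail_index, batch_size=1000, shuffle=True).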
- test(head_index: Tensor, rel_type: Tensor, tail_index: Tensor, batch_size: int, k: int = 10, log: bool = True) → Tuple[float, float, float]
Evaluates the model quality by computing the Mean Rank, the Mean Reciprocal Rank (MRR), and Hits@\(k\) across all possible tail entities.
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
batch_size (int) – The batch size to use for evaluating.
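k (int, optional) – The \(k\) in Hits@\(k\). (default: 10)
log (bool, optional) – If set to False, will not print a progress bar to the console. (default: True)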
- Return type:
Tuple[float, float, float]
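All three returned metrics summarize the rank that the true tail receives among all candidate tails. The sketch below assumes the per-candidate scores for each test query are already computed and uses 1-based ranks, where rank 1 means the true tail scored highest; the helper names tail_rank and rank_metrics are hypothetical, and the library may count ranks or break ties differently. Other known true triplets are not filtered out.

```python
from typing import List, Sequence, Tuple


def tail_rank(scores: Sequence[float], true_tail: int) -> int:
    """1-based rank of the true tail among all candidates (higher score = better)."""
    true_score = scores[true_tail]
    # Count only strictly better candidates, so ties favour the true tail.
    return 1 + sum(1 for s in scores if s > true_score)


def rank_metrics(ranks: List[int], k: int = 10) -> Tuple[float, float, float]:
    """Return (Mean Rank, MRR, Hits@k) for a list of 1-based ranks."""
    n = len(ranks)
    mean_rank = sum(ranks) / n
    mrr = sum(1.0 / r for r in ranks) / n
    hits_at_k = sum(1 for r in ranks if r <= k) / n
    return mean_rank, mrr, hits_at_k


# Scores of 5 candidate tails for two test queries; the true tails are 2 and 0.
queries = [([0.1, 0.9, 0.8, 0.2, 0.3], 2), ([0.7, 0.1, 0.2, 0.3, 0.4], 0)]
ranks = [tail_rank(scores, t) for scores, t in queries]
print(ranks)                     # [2, 1]
print(rank_metrics(ranks, k=1))  # (1.5, 0.75, 0.5)
```

Lower Mean Rank is better, while higher MRR and Hits@\(k\) are better.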
- random_sample(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) → Tuple[Tensor, Tensor, Tensor]
Randomly samples negative triplets by replacing either the head or the tail (but not both).
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
- Return type:
Tuple[Tensor, Tensor, Tensor]
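A minimal sketch of this corruption scheme, assuming entity ids 0..num_nodes-1: it replaces the head of the first half of the triplets and the tail of the remaining half with uniformly random entities, leaving every relation untouched. The half-and-half split, the function name corrupt and the seed argument are assumptions for illustration. A sampled entity may coincidentally equal the original, and nothing checks whether a corrupted triplet is actually false.

```python
import random
from typing import List, Tuple

Triplets = Tuple[List[int], List[int], List[int]]


def corrupt(head_index: List[int], rel_type: List[int], tail_index: List[int],
            num_nodes: int, seed: int = 0) -> Triplets:
    """Replace the head OR the tail (never both) of each triplet (hypothetical helper)."""
    rng = random.Random(seed)
    n = len(head_index)
    num_head = n // 2  # first half: corrupt heads; second half: corrupt tails
    new_head = list(head_index)  # copies, so the inputs are not modified
    new_tail = list(tail_index)
    for i in range(n):
        if i < num_head:
            new_head[i] = rng.randrange(num_nodes)
        else:
            new_tail[i] = rng.randrange(num_nodes)
    # Relations are copied unchanged.
    return new_head, list(rel_type), new_tail


h, r, t = [0, 1, 2, 3], [0, 1, 0, 1], [1, 2, 3, 0]
neg_h, neg_r, neg_t = corrupt(h, r, t, num_nodes=10)
assert neg_r == r
assert neg_t[:2] == t[:2] and neg_h[2:] == h[2:]  # untouched halves
```

The negatives can then be scored with the same forward pass as the positives and compared against them in a ranking-style loss.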