torch_geometric.metrics.LinkPredF1

class LinkPredF1(k: int)[source]

基础类: LinkPredMetric

一个用于计算F1 @ \(k\)的链接预测指标。

Parameters:

k (int) – 用于评估的前\(k\)个预测的数量。

update(pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]) None

根据当前小批量预测更新状态变量。

update() 可以重复多次以累积连续预测的结果,例如,在小批量训练或评估循环中。

Parameters:
  • pred_index_mat (torch.Tensor) – 每个小批量样本的 top-\(k\) 预测,形状为 [batch_size, k]

  • edge_label_index (torch.Tensor) – 每个小批量样本的真实索引,以COO格式给出,形状为 [2, num_ground_truth_indices]

Return type:

None

compute() Tensor

计算最终的度量值。

Return type:

Tensor

reset() None

将度量状态变量重置为其默认值。

Return type:

None