torch_geometric.explain.metric.groundtruth_metrics

groundtruth_metrics(pred_mask: Tensor, target_mask: Tensor, metrics: Optional[Union[str, List[str]]] = None, threshold: float = 0.5) Union[float, Tuple[float, ...]][source]

比较并评估解释掩码与真实解释掩码。

Parameters:
  • pred_mask (torch.Tensor) – 要评估的预测掩码。

  • target_mask (torch.Tensor) – 真实的目标掩码。

  • metrics (strList[str], 可选) – 要返回的指标 ("accuracy", "recall", "precision", "f1_score", "auroc"). (默认: ["accuracy", "recall", "precision", "f1_score", "auroc"])

  • threshold (float, optional) – 用于对 maskgroundtruth 进行硬阈值处理的阈值。 (默认值: 0.5)

Return type:

Union[float, Tuple[float, ...]]