torch_geometric.explain.metric.groundtruth_metrics
- groundtruth_metrics(pred_mask: Tensor, target_mask: Tensor, metrics: Optional[Union[str, List[str]]] = None, threshold: float = 0.5) Union[float, Tuple[float, ...]][source]
比较并评估解释掩码与真实解释掩码。
- Parameters:
pred_mask (torch.Tensor) – 要评估的预测掩码。
target_mask (torch.Tensor) – 真实的目标掩码。
metrics (str 或 List[str], 可选) – 要返回的指标 (
"accuracy","recall","precision","f1_score","auroc"). (默认:["accuracy", "recall", "precision", "f1_score", "auroc"])threshold (float, optional) – 用于对
mask和groundtruth进行硬阈值处理的阈值。 (默认值:0.5)
- Return type: