Shortcuts

torcheval.metrics.functional.binary_recall

torcheval.metrics.functional.binary_recall(input: Tensor, target: Tensor, *, threshold: float = 0.5) Tensor

计算二分类类别的召回率得分,该得分计算为真正例(TP)的数量与实际正例总数(TP + FN)的比率。 其类版本为torcheval.metrics.BinaryRecall

Parameters:
  • input (Tensor) – 预测标签/逻辑/概率的张量,形状为 (n_sample, )。

  • target (Tensor) – 形状为 (n_sample, ) 的真实标签张量。

  • threshold (float, 默认 0.5) – 用于将输入转换为每个样本的预测标签的阈值。 torch.where(input < threshold, 0, 1) 将应用于 input

示例:

>>> import torch
>>> from torcheval.metrics.functional.classification import binary_recall
>>> input = torch.tensor([0, 0, 1, 1])
>>> target = torch.tensor([0, 1, 1, 1])
>>> binary_recall(input, target)
tensor(0.6667)  # 2 / 3
>>> input = torch.tensor([0, 0.2, 0.4, 0.7])
>>> target = torch.tensor([1, 0, 1, 1])
>>> binary_recall(input, target)
tensor(0.3333)  # 1 / 3
>>> binary_recall(input, target, threshold=0.4)
tensor(0.5000)  # 1 / 2