• Docs >
  • Use Metrics in TorchEval
Shortcuts

在TorchEval中使用指标

PyTorch 评估指标是 TorchEval 的核心功能之一。 对于大多数指标,我们提供了基于类的有状态接口,这些接口仅在被告知计算指标时累积必要的数据,以及纯函数式接口。

类指标

类 metrics 跟踪指标状态,这使得它们能够通过跨多个进程的累积和同步来计算值。基类是 torcheval.metrics.Metric

类指标的核心API是update()compute()reset()

  • update(): 使用输入数据更新度量状态。当需要添加新数据以进行度量计算时,通常会使用此方法。

  • compute(): 从度量状态计算度量值,这些状态由之前的update()调用更新。计算频率可以低于更新频率。

  • reset(): 将度量状态变量重置为其默认值。通常在每个epoch结束时调用此函数以清理度量状态。

注意

类指标跟踪由传递给update()调用的输入数据更新的内部状态。这意味着指标状态应移动到与输入数据相同的设备。您可以在初始化时直接传入设备或使用to(device) API。.device属性显示指标状态的设备。

下面是一个在简单训练脚本中使用类度量的示例。

import torch
from torcheval.metrics import MulticlassAccuracy

device = "cuda" if torch.cuda.is_available() else "cpu"
metric = MulticlassAccuracy(device=device)
num_epochs, num_batches, batch_size = 4, 8, 10
num_classes = 3

# number of batches between metric computations
compute_frequency = 2

for epoch in range(num_epochs):
    for batch_idx in range(num_batches):
        input = torch.randint(high=num_classes, size=(batch_size,), device=device)
        target = torch.randint(high=num_classes, size=(batch_size,), device=device)

        # metric.update() updates the metric state with new data
        metric.update(input, target)

        if (batch_idx + 1) % compute_frequency == 0:
                print(
                    "Epoch {}/{}, Batch {}/{} --- acc: {:.4f}".format(
                        epoch + 1,
                        num_epochs,
                        batch_idx + 1,
                        num_batches,
                        # metric.compute() returns metric value from all seen data
                        metric.compute(),
                    )
                )

    # metric.reset() reset metric states. It's typically called after the epoch completes.
    metric.reset()

保存和加载指标

类指标也实现了有状态协议,.state_dict().load_state_dict()。这些函数可用于保存和加载指标。

import torch
from torcheval.metrics import MulticlassAccuracy

metric = MulticlassAccuracy()
input = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])
metric.update(input, target)

state_dict = metric.state_dict()
loaded_metric = MulticlassAccuracy()
loaded_metric.load_state_dict(state_dict)

# returns torch.tensor(0.5)
loaded_metric.compute()

功能指标

功能指标是简单的python函数,它们从输入数据中计算指标值。它们是轻量级的,并且相对较快,因为它们不需要保持和操作指标状态。 下面的示例展示了如何使用功能版本计算指标值。

import torch
from torcheval.metrics.functional import multiclass_accuracy

input = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])
# returns torch.tensor(0.5)
multiclass_accuracy(input, target)