speechbrain.utils.Accuracy 模块
计算准确率。
作者 * Jianyuan Zhong 2020
摘要
类:
用于计算整体一步向前预测准确率的模块。 |
函数:
计算一批预测的对数概率和目标的准确性。 |
参考
- speechbrain.utils.Accuracy.Accuracy(log_probabilities, targets, length=None)[source]
计算一批预测对数概率和目标的准确性。
- Parameters:
log_probabilities (torch.Tensor) – 预测的对数概率 (batch_size, time, feature)。
targets (torch.Tensor) – 目标 (batch_size, time).
length (torch.Tensor) – 目标的长度 (batch_size,).
- Returns:
numerator (float) – 正确样本的数量
denominator (float) – 样本的总数
Example
>>> probs = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2]]).unsqueeze(0) >>> acc = Accuracy(torch.log(probs), torch.tensor([1, 1, 0]).unsqueeze(0), torch.tensor([2/3])) >>> print(acc) (1.0, 2.0)
- class speechbrain.utils.Accuracy.AccuracyStats[source]
基础类:
object用于计算整体一步向前预测准确率的模块。
Example
>>> probs = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2]]).unsqueeze(0) >>> stats = AccuracyStats() >>> stats.append(torch.log(probs), torch.tensor([1, 1, 0]).unsqueeze(0), torch.tensor([2/3])) >>> acc = stats.summarize() >>> print(acc) 0.5