torchmetric 使用阈值计算精度

torchmetric calculate accuracy with threshold

torchmetrics.Accuracy threshold 关键字如何工作?我有以下设置:

import torch, torchmetrics

preds = torch.tensor([[0.3600, 0.3200, 0.3200]])
target = torch.tensor([0])
torchmetrics.functional.accuracy(preds, target, threshold=0.5, num_classes=3)

输出:

tensor(1.)

None preds 中的值概率高于 0.5,为什么它给出 100% 准确率?

threshold=0.50.5 下的每个概率设置为 0。它仅在您处理二进制(这不是您的情况,因为 num_classes=3)或多标签分类(似乎不是这种情况,因为未设置 multiclass)的情况下使用。因此threshold实际上并没有涉及。

在您的例子中,preds 表示与一项观察相关的预测。它的最大概率(0.36)设置在位置0,因此argmax(preds)和target等于0,因此accuracy设置为1,因为(仅) 观察正确。

您可以使用 subset_accuracy=True 的多标签目标形状,如 here

所述
import torch, torchmetrics

target = torch.tensor([0,1,2])
target = torch.nn.functional.one_hot(target)
preds = torch.tensor([[0.3600, 0.3200, 0.3200],
                      [0.300,  0.4000, 0.3000],
                      [0.3000, 0.2000, 0.5000]])

torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.6)

输出:

tensor(0.)

降低阈值:

torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.5)

输出:

tensor(0.333)

将阈值降低到 0.35:

torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.35)

输出:

tensor(1.)

不幸的是,如果任何其他概率高于阈值,这将不起作用,因为它会认为您正在预测多个标签。例如阈值 0.0 将给出:

torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.0)

输出:

tensor(0.)