torchmetric 使用阈值计算精度
torchmetric calculate accuracy with threshold
torchmetrics.Accuracy
threshold 关键字如何工作?我有以下设置:
import torch, torchmetrics
preds = torch.tensor([[0.3600, 0.3200, 0.3200]])
target = torch.tensor([0])
torchmetrics.functional.accuracy(preds, target, threshold=0.5, num_classes=3)
输出:
tensor(1.)
None preds
中的值概率高于 0.5
,为什么它给出 100% 准确率?
threshold=0.5
将 0.5
下的每个概率设置为 0
。它仅在您处理二进制(这不是您的情况,因为 num_classes=3
)或多标签分类(似乎不是这种情况,因为未设置 multiclass
)的情况下使用。因此threshold
实际上并没有涉及。
在您的例子中,preds
表示与一项观察相关的预测。它的最大概率(0.36
)设置在位置0
,因此argmax(preds)和target等于0
,因此accuracy设置为1,因为(仅) 观察正确。
您可以使用 subset_accuracy=True
的多标签目标形状,如 here
所述
import torch, torchmetrics
target = torch.tensor([0,1,2])
target = torch.nn.functional.one_hot(target)
preds = torch.tensor([[0.3600, 0.3200, 0.3200],
[0.300, 0.4000, 0.3000],
[0.3000, 0.2000, 0.5000]])
torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.6)
输出:
tensor(0.)
降低阈值:
torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.5)
输出:
tensor(0.333)
将阈值降低到 0.35:
torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.35)
输出:
tensor(1.)
不幸的是,如果任何其他概率高于阈值,这将不起作用,因为它会认为您正在预测多个标签。例如阈值 0.0 将给出:
torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.0)
输出:
tensor(0.)
torchmetrics.Accuracy
threshold 关键字如何工作?我有以下设置:
import torch, torchmetrics
preds = torch.tensor([[0.3600, 0.3200, 0.3200]])
target = torch.tensor([0])
torchmetrics.functional.accuracy(preds, target, threshold=0.5, num_classes=3)
输出:
tensor(1.)
None preds
中的值概率高于 0.5
,为什么它给出 100% 准确率?
threshold=0.5
将 0.5
下的每个概率设置为 0
。它仅在您处理二进制(这不是您的情况,因为 num_classes=3
)或多标签分类(似乎不是这种情况,因为未设置 multiclass
)的情况下使用。因此threshold
实际上并没有涉及。
在您的例子中,preds
表示与一项观察相关的预测。它的最大概率(0.36
)设置在位置0
,因此argmax(preds)和target等于0
,因此accuracy设置为1,因为(仅) 观察正确。
您可以使用 subset_accuracy=True
的多标签目标形状,如 here
import torch, torchmetrics
target = torch.tensor([0,1,2])
target = torch.nn.functional.one_hot(target)
preds = torch.tensor([[0.3600, 0.3200, 0.3200],
[0.300, 0.4000, 0.3000],
[0.3000, 0.2000, 0.5000]])
torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.6)
输出:
tensor(0.)
降低阈值:
torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.5)
输出:
tensor(0.333)
将阈值降低到 0.35:
torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.35)
输出:
tensor(1.)
不幸的是,如果任何其他概率高于阈值,这将不起作用,因为它会认为您正在预测多个标签。例如阈值 0.0 将给出:
torchmetrics.functional.accuracy(preds, target, subset_accuracy=True, threshold=0.0)
输出:
tensor(0.)