在 ShuffleNet 中计算精度函数

Calculate accuracy function in ShuffleNet

我正在使用 ShuffleNet 的一些代码,但我无法理解此函数中 correct 的计算。(此函数计算精度 1 和 5)。
正如我在第三行中理解的 pred 是索引,但我不明白为什么两行之后用等价函数将它与 target 进行了比较,因为 pred 是索引最大概率输出。

def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0/batch_size))
    return res

查看代码,我可以推测 output 的形状为 (batch_size, n_logits),而目标是密集表示:形状为 (batch_size, 1)。这意味着真实值 class 由整数值指定:相应的 class 标签。

如果我们研究 top-k 准确率的实现,我们首先需要了解这一点:top-k 准确率是关于计算 k 我们输出的最高预测。它本质上是标准 top-1 精度的一种概括形式,我们只会查看单个最高预测并确定它是否与目标匹配。

如果我们以 batch_size=2n_logits=10k=3 为例,我们对顶部感兴趣-3准确度。这里我们抽取一个随机预测:

>>> output
tensor([[0.2110, 0.9992, 0.0597, 0.9557, 0.8316, 0.8407, 0.8398, 0.3631, 0.2889, 0.3226],
        [0.6811, 0.2932, 0.2117, 0.6522, 0.2734, 0.8841, 0.0336, 0.7357, 0.9232, 0.2633]])

我们首先查看 k 最高对数并检索它们的索引:

>>> _, pred = output.topk(k=3, dim=1, largest=True, sorted=True)

>>> pred
tensor([[3, 6, 4],
        [7, 3, 5]])

这只不过是一个切片 torch.argsortoutput.argsort(1, descending=True)[:, :3] 将 return 得到相同的结果。

然后我们可以转置以获得最后的批次 (3, 2):

>>> pred = pred.T
tensor([[3, 7],
        [6, 3],
        [4, 5]])

现在我们有了每个批次元素的 top-k 预测,我们需要将它们与基本事实进行比较。现在让我们想象一个目标张量(记住形状为(batch_size=2, 1)):

>>> target
tensor([[1],
        [5]]) 

我们首先要把它展开成pred的形状:

>>> target.view(1, -1).expand_as(pred)
tensor([[1, 0],
        [1, 0],
        [1, 0]])

然后我们将彼此与 torch.eq 进行比较,元素相等运算符:

>>> correct = torch.eq(pred, target.view(1, -1).expand_as(pred))
tensor([[False, False],
        [False, False],
        [False,  True]])

正如您在第二个批处理元素中看到的那样,最高的三个元素之一与真实 class 标签 5 相匹配。在第一个 batch 元素上,三个最高的预测都不匹配 ground-truth 标签,这是不正确的。第二批元素算一个'correct'.

当然,基于这个等式掩码张量 correct,您可以对其进行更多切片,以计算其他最高 k' 精度,其中 k' <= k。例如 k' = 1:

>>> correct[:1]
tensor([[False, False]])

对于最高 1 的准确度,我们在两个批处理元素中有零个正确实例。