求和技巧的梯度如何在 keras 中获得最大池化位置?

How does the gradient of the sum trick work to get maxpooling positions in keras?

keras 示例目录包含他们在 MNIST 数据上训练的堆叠式自动编码器 (SWWAE) 的轻量级版本。 (https://github.com/fchollet/keras/blob/master/examples/mnist_swwae.py)

在最初的 SWWAE 论文中,作者使用软函数计算了 what 和 where。然而,在 keras 实现中,他们使用了一个技巧来获取这些位置。我想明白这个技巧。

这是技巧的代码。

def getwhere(x):
    ''' Calculate the 'where' mask that contains switches indicating which
    index contained the max value when MaxPool2D was applied.  Using the
    gradient of the sum is a nice trick to keep everything high level.'''
    y_prepool, y_postpool = x
    return K.gradients(K.sum(y_postpool), y_prepool)  # How exactly does this line work?

其中 y_prepool 是一个 MxN 矩阵,y_postpool 是一个 M/2 x N/2 矩阵(假设典型池化大小为 2 个像素)。

我已经验证 getwhere() 的输出是一个钉床矩阵,其中钉子指示最大值的位置(如果您愿意,则为局部 argmax)。

有人可以构建一个小示例来演示 getwhere 如何使用此 "Trick?"

让我们关注最简单的例子,不真正谈论卷积,假设我们有一个向量

x = [1 4 2]

我们最大池化(用一个大的window),我们得到

mx = 4

从数学上来说,就是:

mx = x[argmax(x)]

现在,"trick"恢复一个池化使用的热掩码是

magic = d mx / dx

argmax 没有梯度,但是它 "passes" 对应向量中最大元素位置的一个元素的梯度,所以:

d mx / dx = [0/dx[1] dx[2]/dx[2] 0/dx[3]] = [0 1 0]

如你所见,所有非最大元素的梯度都为零(由于argmax),并且“1”出现在最大值处,因为dx/x = 1.

现在,对于 "proper" maxpool,您有许多池化区域,连接到许多输入位置,因此采用池化值总和的类似梯度,将恢复所有索引。

但是请注意,如果您有大量重叠的内核,此技巧将不起作用 - 您最终可能会得到比“1”更大的值。基本上,如果一个像素被 K 个内核最大池化,那么它的值将是 K,而不是 1,例如:

     [1 ,2, 3]
x =  [13,3, 1]
     [4, 2, 9]

如果我们用 2x2 window 最大池,我们得到

mx = [13,3]
     [13,9]

渐变技巧给你

        [0, 0, 1]
magic = [2, 0, 0]
        [0, 0, 1]