如何使用 Pytorch and/or Numpy 高效地查找多维矩阵数组中的最大值索引

How to Efficiently Find the Indices of Max Values in a Multidimensional Array of Matrices using Pytorch and/or Numpy

背景

在机器学习中处理高维数据很常见。例如,在卷积神经网络 (CNN) 中,每个输入图像的尺寸可能是 256x256,每个图像可能有 3 个颜色通道(红色、绿色和蓝色)。如果我们假设模型一次接收一批 16 张图像,则进入我们 CNN 的输入的维数是 [16,3,256,256]。每个单独的卷积层都需要 [batch_size, in_channels, in_y, in_x] 形式的数据,并且所有这些量通常在层与层之间发生变化(batch_size 除外)。我们对由 [in_y, in_x] 值组成的矩阵使用的术语是 feature map,这个问题是关于在每个特征图中找到最大值及其索引在给定层。

为什么我要这样做?我想对每个特征图应用一个遮罩,我想在每个特征图中应用以最大值为中心的遮罩,为此我需要知道每个最大值在哪里位于。此掩码应用程序是在模型的训练和测试期间完成的,因此效率对于减少计算时间至关重要。有许多 Pytorch 和 Numpy 解决方案可用于查找单例最大值和索引,以及用于查找单个维度上的最大值或索引,但没有(我能找到的)专用和 高效 构建-in 函数用于一次查找 2 个或更多维度的最大值索引。是的,我们可以嵌套在单个维度上运行的函数,但这些是一些效率最低的方法。

我试过的

性能标准

如果我问的是关于效率的问题,我需要清楚地详细说明期望。我正在尝试为上述问题找到一个省时的解决方案(space 是次要的),而无需编写 C code/extensions,并且相当灵活(超专业化方法不是我所追求的) .该方法必须接受数据类型为 float32 或 float64 的 [a,b,c,d] Torch 张量作为输入,并输出数据类型为 int32 或 int64 的 [a,b,2] 形式的数组或张量(因为我们使用输出作为索引)。 解决方案应针对以下典型解决方案进行基准测试:

max_indices = torch.stack([torch.stack([(x[k][j]==torch.max(x[k][j])).nonzero()[0] for j in range(x.size()[1])]) for k in range(x.size()[0])])

方法

我们将利用 Numpy 社区和库,以及 Pytorch 张量和 Numpy 数组可以相互转换 to/from 这一事实,而无需在内存中复制或移动底层数组(所以转换成本低)。来自 Pytorch documentation:

Converting a torch Tensor to a Numpy array and vice versa is a breeze. The torch Tensor and Numpy array will share their underlying memory locations, and changing one will change the other.

解决方案一

我们将首先使用 Numba library 编写一个函数,该函数将在首次使用时进行 just-in-time (JIT) 编译,这意味着我们无需编写 C 代码即可获得 C 速度我们自己。当然,对于什么可以得到 JIT-ed 有一些注意事项,其中一个注意事项是我们使用 Numpy 函数。但这还不算太糟糕,因为请记住,从我们的 torch 张量转换为 Numpy 的成本很低。我们创建的函数是:

@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx

此函数来自位于 here 的另一个 Whosebug 答案(这是将我介绍给 Numba 的答案)。该函数采用 N-Dimensional Numpy 数组并查找给定 item 的第一次出现。它立即 return 找到成功匹配项的索引。 @njit 装饰器是 @jit(nopython=True) 的缩写,它告诉编译器我们希望它使用 no Python 对象编译函数,并抛出如果无法这样做,则会出现错误(当不使用 Python 对象时,Numba 是最快的,速度就是我们所追求的)。

有了这个快速函数的支持,我们可以得到张量中最大值的索引,如下所示:

import numpy as np

x =  x.numpy()
maxVals = np.amax(x, axis=(2,3))
max_indices = np.zeros((n,p,2),dtype=np.int64)
for index in np.ndindex(x.shape[0],x.shape[1]):
    max_indices[index] = np.asarray(indexFunc(x[index], maxVals[index]),dtype=np.int64)
max_indices = torch.from_numpy(max_indices)

我们使用 np.amax 因为它的 axis 参数可以接受一个元组,允许它 return 4D 输入中每个 2D 特征图的最大值。我们提前用np.zeros初始化了max_indices,因为appending to numpy arrays is expensive,所以我们提前分配了我们需要的space。这种方法比问题中的典型解决方案 快(一个数量级),但它还在 JIT-ed 函数之外使用了 for 循环,所以我们可以改进...

方案二

我们将使用以下解决方案:

@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError

@njit(cache=True, parallel=True)
def indexFunc2(x,maxVals):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            max_indices[i,j] = np.asarray(indexFunc(x[i,j], maxVals[i,j]),dtype=np.int64)
    return max_indices

x = x.numpy()
maxVals = np.amax(x, axis=(2,3))
max_indices = torch.from_numpy(indexFunc2(x,maxVals))

我们可以使用 Numba 的 prange 函数(其行为与 range 完全相同),而不是使用 for 循环遍历我们的特征映射 one-at-a-time但告诉编译器我们希望循环被并行化)和 parallel=True 装饰器参数。 Numba 也 parallelizes the np.zeros function。因为我们的函数是编译 Just-In-Time 并且没有使用 Python 对象,所以 Numba 可以利用我们系统中所有可用的线程!值得注意的是现在indexFunc里面多了一个raise RuntimeError。我们需要包含它,否则 Numba 编译器将尝试推断函数的 return 类型并推断它将是数组或 None。这与我们在 indexFunc2 中的用法不一致,因此编译器会抛出错误。当然,从我们的设置 我们知道 indexFunc 总是 return 一个数组,所以我们可以简单地在另一个逻辑分支中引发和错误。

此方法在功能上与解决方案一相同,但将使用 nd.index 的迭代更改为使用 prange 的两个 for 循环。这种方法比解决方案一快 4 倍。

方案三

解决方案二很快,但它仍然使用正则 Python 寻找最大值。我们可以使用更全面的 JIT-ed 函数加快速度吗?

@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError

@njit(cache=True, parallel=True)
def indexFunc3(x):
    maxVals = np.zeros((x.shape[0],x.shape[1]),dtype=np.float32)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxVals[i][j] = np.max(x[i][j])
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            x[i][j] == np.max(x[i][j])
            max_indices[i,j] = np.asarray(indexFunc(x[i,j], maxVals[i,j]),dtype=np.int64)
    return max_indices

max_indices = torch.from_numpy(indexFunc3(x))

这个解决方案看起来可能还有很多其他内容,但唯一的变化是,我们现在已将操作并行化,而不是使用 np.amax 计算每个特征图的最大值。这种方法比解决方案二稍微快一些。

方案四

这个解决方案是我能想到的最好的解决方案:

@njit(cache=True, parallel=True)
def indexFunc4(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[2], maxTemp % x.shape[2]] 
    return max_indices

max_indices = torch.from_numpy(indexFunc4(x))

此方法更简洁,速度也最快,比解决方案三快 33%,比典型解决方案快 50 倍。我们使用 np.argmax 来获取每个特征图的最大值的索引, 但是 np.argmax 只有 return 索引就像每个特征图一样被夷为平地。也就是说,我们得到一个整数告诉我们元素在我们的特征映射中的编号,而不是我们需要能够访问该元素的索引。数学 [maxTemp // x.shape[2], maxTemp % x.shape[2]] 是将那个单数 int 变成我们需要的 [row,column]

基准测试

所有方法都针对形状为 [32,d,64,64] 的随机输入进行基准测试,其中 d 从 5 增加到 245。对于每个 d,收集 15 个样本并对时间进行平均。相等性测试确保所有解决方案都提供相同的值。基准输出的一个例子是:

随着 d 的增加,基准测试时间的图表是(省略了典型的解决方案,因此图表没有被压扁):

哇!这些尖峰一开始是怎么回事?

解决方法五

Numba 允许我们生成 Just-In-Time 编译函数,但直到我们第一次使用它们时它才会编译它们;然后它会缓存结果以供我们访问时使用l 功能再次。这意味着我们第一次调用我们的 JIT-ed 函数时,我们会在编译函数时得到计算时间的峰值。幸运的是,有一种方法可以解决这个问题——如果我们提前指定函数的 return 类型和参数类型,函数将被提前编译而不是编译 just-in-time。将这些知识应用于解决方案四,我们得到:

@njit('i8[:,:,:](f4[:,:,:,:])',cache=True, parallel=True)
def indexFunc4(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[2], maxTemp % x.shape[2]] 
    return max_indices    

max_indices6 = torch.from_numpy(indexFunc4(x))

如果我们重新启动内核并重新运行我们的基准测试,我们可以查看第一个结果 d==5 和第二个结果 d==10 并注意所有 JIT-ed d==5 时解决方案较慢,因为它们必须被编译,解决方案四除外,因为我们提前明确提供了函数签名:

我们开始了!这是迄今为止我对这个问题的最佳解决方案。


编辑#1

解法六

已开发出改进的解决方案,比之前发布的最佳解决方案快 33%。此解决方案仅在输入数组为 C-contiguous 时有效,但这不是一个大限制,因为 numpy 数组或火炬张量将是连续的,除非它们被重塑,并且两者都具有使 array/tensor 连续的功能如果需要的话。

这个解决方案与之前的最佳解决方案相同,但是指定输入和 return 类型的函数装饰器从

@njit('i8[:,:,:](f4[:,:,:,:])',cache=True, parallel=True)

@njit('i8[:,:,::1](f4[:,:,:,::1])',cache=True, parallel=True)

唯一的区别是每个数组类型中的最后一个 : 变为 ::1,这会向 numba njit 编译器发出信号,表明输入数组是 C-contiguous,从而使其更好优化。

那么完整的解法六就是:

@njit('i8[:,:,::1](f4[:,:,:,::1])',cache=True, parallel=True)
def indexFunc5(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[2], maxTemp % x.shape[2]] 
    return max_indices 

max_indices7 = torch.from_numpy(indexFunc5(x))

包含这个新解决方案的基准测试证实了加速: