Parallelizing a maximum over an nd-array using Numba

I am trying to use Numba to parallelize a Python function that takes two numpy ndarrays, alpha and beta, as arguments. They have shapes of the form (a,m,n) and (b,m,n) respectively, so they broadcast over the trailing dimensions. The function computes the matrix dot product (Frobenius product) of the 2-d slices of the arguments, and for each slice of alpha finds the beta slice that maximizes the product. In code:

@njit(parallel=True)
def parallel_value(alpha,beta):
    values=np.empty(alpha.shape[0])
    indices=np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot=np.sum(alpha[i,:,:]*beta,axis=(1,2))
        index=np.argmax(dot)
        values[i]=dot[index]
        indices[i]=index
    return values,indices
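(For reference, the per-slice quantity being maximized is the Frobenius inner product, which for a single pair of 2-d slices is just the elementwise product summed over both axes; a minimal numpy sketch, equivalent to one entry of `dot` above:)

```python
import numpy as np

# Frobenius inner product of two (m, n) matrices: elementwise
# multiplication followed by summation over both axes.
m, n = 4, 3
rng = np.random.default_rng(0)
X = rng.random((m, n))
Y = rng.random((m, n))

frob = np.sum(X * Y)      # what each entry of `dot` computes
same = np.trace(X @ Y.T)  # equivalent trace formulation
assert np.isclose(frob, same)
```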

This runs fine without the njit decorator, but the Numba compiler complains:

No implementation of function Function(<built-in function setitem>) found for signature:

>>> setitem(array(float64, 1d, C), int64, array(float64, 1d, C))

The offending line is apparently values[i]=dot[index]. I don't see why this should be a problem. What is the cause of this issue, and how can I fix it?

Also, would there be any benefit to adding nogil=True to the arguments of @njit?

I managed to reproduce your problem. When running the code:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value(alpha,beta):
    values=np.empty(alpha.shape[0])
    indices=np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot=np.sum(alpha[i,:,:]*beta,axis=(1,2))
        index=np.argmax(dot)
        values[i]=dot[index]
        indices[i]=index
    return values,indices


a, b, m, n = 6, 5, 4, 3
parallel_value(np.random.rand(a, m, n), np.random.rand(b, m, n))

I get the error message:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 
 >>> setitem(array(float64, 1d, C), int64, array(float64, 1d, C))
 
There are 16 candidate implementations:
      - Of which 16 did not match due to:
      Overload of function 'setitem': File: <numerous>: Line N/A.
        With argument(s): '(array(float64, 1d, C), int64, array(float64, 1d, C))':
       No match.

During: typing of setitem at <ipython-input-41-44518cf5219f> (11)

File "<ipython-input-41-44518cf5219f>", line 11:
def parallel_value(alpha,beta):
    <source elided>
        index=np.argmax(dot)
        values[i]=dot[index]
        ^

According to this issue on the project's GitHub page, this kind of reduction can be problematic in numba.
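My understanding (take this as an educated guess) is that Numba's np.sum does not support a tuple axis argument, which is why dot, and hence dot[index], ends up typed as an array instead of a scalar. One possible workaround, sketched below in plain numpy (I have not re-tested it under njit), is to chain single-axis reductions, so that each np.sum call gets a single integer axis:

```python
import numpy as np

# Sketch of a possible workaround (not re-tested under njit here):
# replace the tuple-axis reduction with chained single-axis sums.
x = np.arange(24.0).reshape(2, 3, 4)

tuple_axis = np.sum(x, axis=(1, 2))             # what the original code does
chained    = np.sum(np.sum(x, axis=2), axis=1)  # single-integer-axis form

assert np.allclose(tuple_axis, chained)
```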

When I rewrote the code with explicit loops, it works:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value_numba(alpha,beta):
    values  = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot = np.zeros(beta.shape[0])
        for j in range(beta.shape[0]):          # inner loops run serially;
            for k in range(beta.shape[1]):      # only the outer prange is
                for l in range(beta.shape[2]):  # parallelized by Numba
                    dot[j] += alpha[i,k,l]*beta[j, k, l]
        index=np.argmax(dot)
        values[i]=dot[index]
        indices[i]=index
    return values,indices

def parallel_value_nonumba(alpha,beta):
    values=np.empty(alpha.shape[0])
    indices=np.empty(alpha.shape[0])
    for i in range(alpha.shape[0]):
        dot=np.sum(alpha[i,:,:]*beta,axis=(1,2))
        index=np.argmax(dot)
        values[i]=dot[index]
        indices[i]=index
    return values,indices


a, b, m, n = 6, 5, 4, 3
np.random.seed(42)
A = np.random.rand(a, m, n)
B = np.random.rand(b, m, n)
res_num   = parallel_value_numba(A, B)
res_nonum = parallel_value_nonumba(A, B)

print(f'res_num = {res_num}')
print(f'res_nonum = {res_nonum}')

Output:

res_num = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
       3.43988156]), array([1., 3., 1., 1., 1., 1.]))
res_nonum = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
       3.43988156]), array([1., 3., 1., 1., 1., 1.]))

As far as I can tell, the explicit loops don't seem to hurt performance. I can't compare against the same code without them, since that version won't compile under numba, but I'd guess it doesn't matter:

%timeit res_num   = parallel_value_numba(A, B)
%timeit res_nonum = parallel_value_nonumba(A, B)

Output:

The slowest run took 1472.03 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 4.92 µs per loop
10000 loops, best of 5: 76.9 µs per loop

Finally, you can do this more efficiently using numpy alone by vectorizing your code. It's almost as fast as the numba version with explicit loops, and you avoid the initial compilation delay. Here's how you could do it:

def parallel_value_np(alpha,beta):
    alpha   = alpha.reshape(alpha.shape[0], 1, alpha.shape[1], alpha.shape[2])
    beta    = beta.reshape(1, beta.shape[0], beta.shape[1], beta.shape[2])
    dot     = np.sum(alpha*beta, axis=(2,3))
    indices = np.argmax(dot, axis = 1)
    values  = dot[np.arange(len(indices)), indices]
    return values,indices


res_np = parallel_value_np(A, B)
print(f'res_num = {res_np}')

%timeit res_num   = parallel_value_numba(A, B)

Output:

res_num = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
       3.43988156]), array([1, 3, 1, 1, 1, 1]))
The slowest run took 5.46 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 16.1 µs per loop
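As a side note, the same reduction can also be written with np.einsum, which avoids materializing the broadcasted (a, b, m, n) intermediate array that the reshape-and-multiply version allocates. This is a sketch I haven't benchmarked, but it produces the same dot matrix:

```python
import numpy as np

def parallel_value_einsum(alpha, beta):
    # Contract over the trailing (m, n) axes of every alpha/beta slice pair,
    # producing the (a, b) matrix of Frobenius products in one call, without
    # allocating the full broadcasted (a, b, m, n) intermediate.
    dot     = np.einsum('amn,bmn->ab', alpha, beta)
    indices = np.argmax(dot, axis=1)
    values  = np.max(dot, axis=1)
    return values, indices

rng = np.random.default_rng(42)
A = rng.random((6, 4, 3))
B = rng.random((5, 4, 3))
vals, idx = parallel_value_einsum(A, B)
```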