Numba 抱怨打字 - 但所有类型都已提供

Numba complains about typing - but all types are being provided

我在使用 Numba 打字时遇到问题 - 我阅读了手册,但最终还是碰壁了。

有问题的功能是一个更大项目的一部分 - 虽然它需要 运行 快速 - Python 列表是不可能的,因此我决定尝试 Numba。遗憾的是,该函数在 nopython=True 模式下失败,尽管事实上 - 根据我的理解 - 提供了所有类型。

代码如下:

from Numba import jit, njit, uint8, int64, typeof

@jit(uint8[:,:,:](int64))
def findWhite(cropped):
    h1 = int64(0)
    for i in cropped:
        for j in i:
            if np.sum(j) == 765:
                h1 = h1 + int64(1)
            else:
                pass
    return h1

另外,分别:

print(typeof(cropped))
array(uint8, 3d, C)
print(typeof(h1))
int64

在这种情况下 'cropped' 是一个大的 uint8 3D C 矩阵(RGB tiff 文件理解 - PIL.Image)。有人可以向 Numba 新手解释我做错了什么吗?

你考虑过使用 Numpy 吗?这通常是 Python 列表和 Numba 之间的一个很好的中间值,例如:

h1 = (cropped.sum(axis=-1) == 765).sum()

h1 = (cropped == 255).all(axis=-1).sum()

您提供的示例代码不是有效的 Numba。您的签名也不正确,因为输入是一个 3D 数组,输出是一个整数,它可能应该是:

@njit(int64(uint8[:,:,:]))

像您一样循环遍历数组不是有效代码。您的代码的紧密翻译应该是这样的:

@njit(int64(uint8[:,:,:]))
def findWhite(cropped):

    h1 = int64(0)    
    ys, xs, n_bands = cropped.shape

    for i in range(ys):
        for j in range(xs):
            if cropped[i, j, :].sum() == 765:
                h1 += 1

    return h1

但这不是很快,也没有在我的机器上打败 Numpy。使用 Numba 可以明确地遍历数组中的每个元素,这已经快很多了:

@njit(int64(uint8[:,:,:]))
def findWhite_numba(cropped):

    h1 = int64(0)    
    ys, xs, zs = cropped.shape

    for i in range(ys):
        for j in range(xs):

            incr = 1
            for k in range(zs):

                if cropped[i, j, k] != 255:
                    incr = 0
                    break

            h1 += incr

    return h1

对于 5000x5000x3 数组,这些是我的结果:

Numpy(h1 = (cropped == 255).all(axis=-1).sum()):

427 ms ± 6.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

找到白色:

612 ms ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

findWhite_numba:

31 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Numpy 方法的一个好处是它可以推广到任意数量的维度。