Numba 抱怨打字 - 但所有类型都已提供
Numba complains about typing - but all types are being provided
我在使用 Numba 打字时遇到问题 - 我阅读了手册,但最终还是碰壁了。
有问题的功能是一个更大项目的一部分 - 虽然它需要 运行 快速 - Python 列表是不可能的,因此我决定尝试 Numba。遗憾的是,该函数在 nopython=True 模式下失败,尽管事实上 - 根据我的理解 - 提供了所有类型。
代码如下:
from Numba import jit, njit, uint8, int64, typeof
@jit(uint8[:,:,:](int64))
def findWhite(cropped):
h1 = int64(0)
for i in cropped:
for j in i:
if np.sum(j) == 765:
h1 = h1 + int64(1)
else:
pass
return h1
另外,分别:
print(typeof(cropped))
array(uint8, 3d, C)
print(typeof(h1))
int64
在这种情况下 'cropped' 是一个大的 uint8 3D C 矩阵(RGB tiff 文件理解 - PIL.Image)。有人可以向 Numba 新手解释我做错了什么吗?
你考虑过使用 Numpy 吗?这通常是 Python 列表和 Numba 之间的一个很好的中间值,例如:
h1 = (cropped.sum(axis=-1) == 765).sum()
或
h1 = (cropped == 255).all(axis=-1).sum()
您提供的示例代码不是有效的 Numba。您的签名也不正确,因为输入是一个 3D 数组,输出是一个整数,它可能应该是:
@njit(int64(uint8[:,:,:]))
像您一样循环遍历数组不是有效代码。您的代码的紧密翻译应该是这样的:
@njit(int64(uint8[:,:,:]))
def findWhite(cropped):
h1 = int64(0)
ys, xs, n_bands = cropped.shape
for i in range(ys):
for j in range(xs):
if cropped[i, j, :].sum() == 765:
h1 += 1
return h1
但这不是很快,也没有在我的机器上打败 Numpy。使用 Numba 可以明确地遍历数组中的每个元素,这已经快很多了:
@njit(int64(uint8[:,:,:]))
def findWhite_numba(cropped):
h1 = int64(0)
ys, xs, zs = cropped.shape
for i in range(ys):
for j in range(xs):
incr = 1
for k in range(zs):
if cropped[i, j, k] != 255:
incr = 0
break
h1 += incr
return h1
对于 5000x5000x3 数组,这些是我的结果:
Numpy(h1 = (cropped == 255).all(axis=-1).sum()
):
427 ms ± 6.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
找到白色:
612 ms ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
findWhite_numba:
31 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Numpy 方法的一个好处是它可以推广到任意数量的维度。
我在使用 Numba 打字时遇到问题 - 我阅读了手册,但最终还是碰壁了。
有问题的功能是一个更大项目的一部分 - 虽然它需要 运行 快速 - Python 列表是不可能的,因此我决定尝试 Numba。遗憾的是,该函数在 nopython=True 模式下失败,尽管事实上 - 根据我的理解 - 提供了所有类型。
代码如下:
from Numba import jit, njit, uint8, int64, typeof
@jit(uint8[:,:,:](int64))
def findWhite(cropped):
h1 = int64(0)
for i in cropped:
for j in i:
if np.sum(j) == 765:
h1 = h1 + int64(1)
else:
pass
return h1
另外,分别:
print(typeof(cropped))
array(uint8, 3d, C)
print(typeof(h1))
int64
在这种情况下 'cropped' 是一个大的 uint8 3D C 矩阵(RGB tiff 文件理解 - PIL.Image)。有人可以向 Numba 新手解释我做错了什么吗?
你考虑过使用 Numpy 吗?这通常是 Python 列表和 Numba 之间的一个很好的中间值,例如:
h1 = (cropped.sum(axis=-1) == 765).sum()
或
h1 = (cropped == 255).all(axis=-1).sum()
您提供的示例代码不是有效的 Numba。您的签名也不正确,因为输入是一个 3D 数组,输出是一个整数,它可能应该是:
@njit(int64(uint8[:,:,:]))
像您一样循环遍历数组不是有效代码。您的代码的紧密翻译应该是这样的:
@njit(int64(uint8[:,:,:]))
def findWhite(cropped):
h1 = int64(0)
ys, xs, n_bands = cropped.shape
for i in range(ys):
for j in range(xs):
if cropped[i, j, :].sum() == 765:
h1 += 1
return h1
但这不是很快,也没有在我的机器上打败 Numpy。使用 Numba 可以明确地遍历数组中的每个元素,这已经快很多了:
@njit(int64(uint8[:,:,:]))
def findWhite_numba(cropped):
h1 = int64(0)
ys, xs, zs = cropped.shape
for i in range(ys):
for j in range(xs):
incr = 1
for k in range(zs):
if cropped[i, j, k] != 255:
incr = 0
break
h1 += incr
return h1
对于 5000x5000x3 数组,这些是我的结果:
Numpy(h1 = (cropped == 255).all(axis=-1).sum()
):
427 ms ± 6.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
找到白色:
612 ms ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
findWhite_numba:
31 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Numpy 方法的一个好处是它可以推广到任意数量的维度。