Numba nopython 模式不能接受二维布尔索引
Numba nopython mode cannot accept 2-D boolean indexing
我正在尝试使用 numba
加速代码(目前我正在使用 numba 0.45.1
)并遇到布尔索引问题。代码如下:
from numba import njit
import numpy as np
n_max = 1000
n_arr = np.hstack((np.arange(1,3),
np.arange(3,n_max, 3)
))
@njit
def func(arr):
idx = np.arange(arr[-1]).reshape((-1,1)) < arr -2
result = np.zeros(idx.shape)
result[idx] = 10.1
return result
new_arr = func(n_arr)
一旦我 运行 代码,我就会收到以下消息
TypingError: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 2d, C), array(bool, 2d, C), float64)
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
In definition 6:
All templates rejected with literals.
In definition 7:
All templates rejected without literals.
In definition 8:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
In definition 9:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of setitem at C:/Users/User/Desktop/all python file/5.5.5/numba index broadcasting2.py (29)
注意最后一行的(29)
对应的是第29行,也就是result[idx] = 10.1
,这一行我尝试给索引为idx
的result赋值,一个2 -D 布尔索引。
我想解释一下,必须在 @njit
中包含该语句 result[idx] = 10.1
。尽管我想在 @njit
中排除这条语句,但我做不到,因为这一行恰好位于我正在处理的代码的中间。
如果我坚持要在 @njit
中包含赋值语句 result[idx] = 10.1
,究竟需要更改什么才能使其生效?如果可能的话,我希望看到一些代码示例,其中涉及 @njit
中的二维布尔索引,可以是 运行.
谢谢
Numba 当前不支持使用二维数组进行花式索引。参见:
https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#array-access
但是,您可以通过明确地使用 for 循环重写您的函数而不是依赖广播来获得等效的行为:
from numba import njit
import numpy as np
n_max = 1000
n_arr = np.hstack((np.arange(1,3),
np.arange(3,n_max, 3)
))
def func(arr):
idx = np.arange(arr[-1]).reshape((-1,1)) < arr -2
result = np.zeros(idx.shape)
result[idx] = 10.1
return result
@njit
def func2(arr):
M = arr[-1]
N = arr.shape[0]
result = np.zeros((M, N))
for i in range(M):
for j in range(N):
if i < arr[j] - 2:
result[i, j] = 10.1
return result
new_arr = func(n_arr)
new_arr2 = func2(n_arr)
print(np.allclose(new_arr, new_arr2)) # True
在我的机器上,使用您提供的示例输入,func2
比 func
快 3.5 倍。
我正在尝试使用 numba
加速代码(目前我正在使用 numba 0.45.1
)并遇到布尔索引问题。代码如下:
from numba import njit
import numpy as np
n_max = 1000
n_arr = np.hstack((np.arange(1,3),
np.arange(3,n_max, 3)
))
@njit
def func(arr):
idx = np.arange(arr[-1]).reshape((-1,1)) < arr -2
result = np.zeros(idx.shape)
result[idx] = 10.1
return result
new_arr = func(n_arr)
一旦我 运行 代码,我就会收到以下消息
TypingError: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 2d, C), array(bool, 2d, C), float64)
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
In definition 6:
All templates rejected with literals.
In definition 7:
All templates rejected without literals.
In definition 8:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
In definition 9:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of setitem at C:/Users/User/Desktop/all python file/5.5.5/numba index broadcasting2.py (29)
注意最后一行的(29)
对应的是第29行,也就是result[idx] = 10.1
,这一行我尝试给索引为idx
的result赋值,一个2 -D 布尔索引。
我想解释一下,必须在 @njit
中包含该语句 result[idx] = 10.1
。尽管我想在 @njit
中排除这条语句,但我做不到,因为这一行恰好位于我正在处理的代码的中间。
如果我坚持要在 @njit
中包含赋值语句 result[idx] = 10.1
,究竟需要更改什么才能使其生效?如果可能的话,我希望看到一些代码示例,其中涉及 @njit
中的二维布尔索引,可以是 运行.
谢谢
Numba 当前不支持使用二维数组进行花式索引。参见:
https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#array-access
但是,您可以通过明确地使用 for 循环重写您的函数而不是依赖广播来获得等效的行为:
from numba import njit
import numpy as np
n_max = 1000
n_arr = np.hstack((np.arange(1,3),
np.arange(3,n_max, 3)
))
def func(arr):
idx = np.arange(arr[-1]).reshape((-1,1)) < arr -2
result = np.zeros(idx.shape)
result[idx] = 10.1
return result
@njit
def func2(arr):
M = arr[-1]
N = arr.shape[0]
result = np.zeros((M, N))
for i in range(M):
for j in range(N):
if i < arr[j] - 2:
result[i, j] = 10.1
return result
new_arr = func(n_arr)
new_arr2 = func2(n_arr)
print(np.allclose(new_arr, new_arr2)) # True
在我的机器上,使用您提供的示例输入,func2
比 func
快 3.5 倍。