Numba failure with np.mean
For some reason, Numba fails when I add an axis argument to np.mean. For example, this gives an error:
import numpy as np
from numba import jit
@jit(nopython=True)
def num_prac(a):
    return np.mean(a, -1)
b = np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
print(num_prac(b))
TypingError: Invalid use of Function(<function mean at 0x000002949B28E1E0>) with argument(s) of type(s): (array(int32, 2d, C), Literal[int](1))
* parameterized
In definition 0:
AssertionError:
raised from C:\ProgramData\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:649
In definition 1:
AssertionError:
raised from C:\ProgramData\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:649
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function mean at 0x000002949B28E1E0>)
[2] During: typing of call at C:/Users/U374235/test.py (11)
However, this works perfectly fine:
import numpy as np
from numba import jit
@jit(nopython=True)
def num_prac(a):
    return np.mean(a)
b = np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
print(num_prac(b))
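Numba does not support any arguments to np.mean() other than the array itself, so the axis argument is not available in nopython mode.
You can do the following to get a similar result: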
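import numpy as np
from numba import jit, prange
a = np.array([[0, 1, 2], [3, 4, 5]])
res_numpy = np.mean(a, -1)
@jit(nopython=True, parallel=True)
def mean_numba(a):
    # Preallocate the output; appending to a shared list inside prange is not thread-safe
    res = np.empty(a.shape[0])
    for i in prange(a.shape[0]):
        # mean() without arguments is supported on each 1D row
        res[i] = a[i, :].mean()
    return res
# allclose instead of array_equal: float results may differ in the last bits
print(np.allclose(res_numpy, mean_numba(a)))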
Related GitHub issue: https://github.com/numba/numba/issues/1269