如何在 numba.njit() 中使用 numpy.array 的索引?

how to use index of numpy.array in numba.njit()?

如何在numba.njit()中使用numpy.array的索引?下面,如果使用numba.njit,代码会报错退出。我发现错误是由于“b = a[idx]”引起的。但实际上,那应该是python才对。如何在numba中纠正它?谢谢

@numba.njit()
def test(a):
    idx = np.where(a>5)
    b   = a[idx]
    return b

a = np.linspace(0,15,16).reshape([4,4])
b = test(a)

查看此文档

http://numba.pydata.org/numba-doc/0.15.1/arrays.html

b = 测试(a)

尝试像

那样改变这里的变量

k=test(a),正如我认为 test(a) = b 那样意味着 b=b

试一试是否有效...

docs也支持高级索引的一个子集:只允许一个高级索引,而且它必须是一维数组

如果你 运行 你的代码没有 numba,你可以看到结果是一个一维数组:

>>> a[np.where(a > 5)]
array([ 6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15.])

所以可以直接对一维数组进行操作:

@nb.njit()
def test(a):
    a = a.ravel()
    idx = np.where(a > 5)
    b = a[idx]
    return b

或者更简单:

@nb.njit()
def test(a):
    a = a.ravel()
    return a[a > 5]