Numba - Accessing Numpy Array with variable index

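I'm looking into using numba to speed up iterative calculations, particularly in cases where a calculation sometimes depends on the result of a previous one, so vectorization isn't always applicable. One thing I've found lacking is that it doesn't seem to allow dataframes. I figured that's fine, though: you can pass in a 2D numpy array and a numpy array of column names, and I tried to implement a function that references values by column name rather than by index. This is the code I have so far.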

from numba import jit
import numpy as np
@jit(nopython=True)
def get_index(cols,col):
    for i in range(len(cols)):
        if cols[i] == col:
            return i
@jit(nopython=True)
def get_element(ndarr: np.ndarray,cols:np.ndarray,row:np.int8,name:str):
    ind = get_index(cols,name)
    print(row)
    print(ind)
    print(ndarr[0][0])
    #print(ndarr[row][ind])
get_element(np.array([['HI'],['BYE'],['HISAHASDG']]),np.array(['COLUMN_1']),0,"COLUMN_1")

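I have get_index, which I've tested on its own and which works. It's basically an implementation of np.where, and I was wondering whether that might be what's causing my error. With the last print commented out, the code runs: it prints 0, 0, and then "HI", as expected. So in theory the commented-out line should also print "HI", just like the print on the line before it, since row and ind are both 0. But when I uncomment it, I get the following: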

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<timed exec> in <module>

/sas/python/app/miniconda3/envs/py3lu/lib/python3.6/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
    399                 e.patch_message(msg)
    400 
--> 401             error_rewrite(e, 'typing')
    402         except errors.UnsupportedError as e:
    403             # Something unsupported is present in the user code, add help info

/sas/python/app/miniconda3/envs/py3lu/lib/python3.6/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
    342                 raise e
    343             else:
--> 344                 reraise(type(e), e, None)
    345 
    346         argtypes = []

/sas/python/app/miniconda3/envs/py3lu/lib/python3.6/site-packages/numba/core/utils.py in reraise(tp, value, tb)
     78         value = tp()
     79     if value.__traceback__ is not tb:
---> 80         raise value.with_traceback(tb)
     81     raise value
     82 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array([unichr x 50], 1d, C), OptionalType(int64) i.e. the type 'int64 or None')
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    All templates rejected with literals.
In definition 9:
    All templates rejected without literals.
In definition 10:
    All templates rejected with literals.
In definition 11:
    All templates rejected without literals.
In definition 12:
    TypeError: unsupported array index type OptionalType(int64) i.e. the type 'int64 or None' in [OptionalType(int64)]
    raised from /sas/python/app/miniconda3/envs/py3lu/lib/python3.6/site-packages/numba/core/typing/arraydecl.py:69
In definition 13:
    TypeError: unsupported array index type OptionalType(int64) i.e. the type 'int64 or None' in [OptionalType(int64)]
    raised from /sas/python/app/miniconda3/envs/py3lu/lib/python3.6/site-packages/numba/core/typing/arraydecl.py:69
In definition 14:
    All templates rejected with literals.
In definition 15:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at <timed exec> (15)

File "<timed exec>", line 15:
<source missing, REPL/exec in use?>

Is there something I'm missing? I've checked the types of row and ind, and they are indeed ints. Why won't numba let me subset with an int variable? Thanks

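numba is actually being quite clever here! Consider what happens when you pass get_index a col that is not in cols. cols[i] == col will never be true, the loop will exit, and since there is no catch-all return at the end of the function, the return value will be None.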
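The same fall-through can be seen without numba. A minimal plain-Python sketch, using a list in place of the NumPy array; the function name get_index_plain and the column name "MISSING" are made up for illustration:

```python
def get_index_plain(cols, col):
    # Same loop as get_index, minus the @jit decorator.
    for i in range(len(cols)):
        if cols[i] == col:
            return i
    # No return statement here: Python implicitly returns None.


cols = ["COLUMN_1"]
found = get_index_plain(cols, "COLUMN_1")    # column is present
missing = get_index_plain(cols, "MISSING")   # column is absent
print(found, missing)
```

This prints `0 None`: the function can hand back either an int or None, which is the situation numba has to account for when it types the function.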

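numba therefore correctly infers that the return type of get_index is OptionalType(int64), i.e. a value that may be either an int64 or None. But None is not a valid type for an index, so you can't use a value that might be None to index an array.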

You can fix this by adding a catch-all return at the end.

@jit(nopython=True)
def get_index(cols, col):
    for i in range(len(cols)):
        if cols[i] == col:
            return i
    return -1
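
One caveat with a -1 sentinel, sketched here in plain Python (the name get_index_sentinel, the row values, and the column name "MISSING" are made up for illustration): -1 is itself a valid index, so a caller that forgets to check it gets the last column rather than an error. NumPy arrays wrap negative indices in the same way.

```python
def get_index_sentinel(cols, col):
    # Plain-Python copy of the version above, without @jit.
    for i in range(len(cols)):
        if cols[i] == col:
            return i
    return -1


cols = ["A", "B", "C"]
row = [10, 20, 30]

ind = get_index_sentinel(cols, "MISSING")
silently_wrong = row[ind]   # -1 wraps around to the last element

# The caller has to test the sentinel explicitly before indexing.
value = row[ind] if ind >= 0 else None
print(ind, silently_wrong, value)
```

This prints `-1 30 None`: the unchecked lookup quietly returns the value from column "C".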

Of course, that may not be the behaviour you want in this case; it would be better to raise an exception, which numba also handles correctly.

@jit(nopython=True)
def get_index(cols, col):
    for i in range(len(cols)):
        if cols[i] == col:
            return i
    raise IndexError('list index out of range')