Numba / Numpy - 理解错误信息

Numba / Numpy - Understanding Error Message

我正在尝试用 Numba 加速我正在实现的并查集(union-find)算法。下面是一些示例代码。在用一些样本数据测试时,我看不懂 Numba 报出的类型错误。

from numba import jit
import numpy as np

# Number of elements tracked by the union-find structure.
NUM_ELEMENTS = 8806806

# Parent links: every element starts out as its own root.
indices = np.arange(NUM_ELEMENTS, dtype=np.int64)
# Size of the tree rooted at each index (only kept up to date for roots).
sizes = np.ones(NUM_ELEMENTS, dtype=np.int64)
# Number of disjoint sets; initially every element is its own set.
connected_components = NUM_ELEMENTS

@jit(nopython=True)
def root(p: int) -> int:
    """Return the representative (root) of the set containing element ``p``.

    Follows parent links stored in the module-level ``indices`` array until
    reaching an element that is its own parent. Along the way each visited
    element is re-pointed at its grandparent (path halving), which keeps the
    trees shallow for later lookups.

    NOTE(review): Numba treats module-level globals as compile-time constants
    and global NumPy arrays as read-only, so the write to ``indices`` below is
    likely rejected when this is compiled in nopython mode. Passing the
    parent array in as an argument avoids that; confirm against the Numba
    version in use.
    """
    while p != indices[p]:
        # Path halving: skip one level before stepping upward.
        indices[p] = indices[indices[p]]
        p = indices[p]
    return p

@jit(nopython=True)
def connected(p: int, q: int) -> bool:
    """Return True if ``p`` and ``q`` belong to the same set.

    Two elements are connected exactly when they share the same root.
    Because ``root`` performs path halving on the module-level ``indices``
    array, this query is not read-only.
    """
    return root(p) == root(q)

@jit(nopython=True)
def union(p: int, q: int) -> bool:
    """Merge the sets containing ``p`` and ``q`` using union by size.

    The root of the smaller tree is attached beneath the root of the larger
    tree (ties attach ``q``'s root under ``p``'s), and the surviving root's
    size is increased by the size of the absorbed tree.

    Returns:
        ``True`` if two distinct sets were merged, ``False`` if ``p`` and
        ``q`` were already in the same set. A caller tracking the number of
        connected components should decrement its counter on ``True``.

    NOTE(review): reads and writes the module-level ``indices`` and ``sizes``
    arrays. Numba freezes module globals as read-only constants, so these
    arrays probably need to be passed in as arguments. Confirm.
    """
    root1 = root(p)
    root2 = root(q)
    if root1 == root2:
        return False

    if sizes[root1] < sizes[root2]:
        indices[root1] = root2
        sizes[root2] += sizes[root1]
    else:
        indices[root2] = root1
        sizes[root1] += sizes[root2]

    # The component count is reported through the return value rather than by
    # rebinding the module-level ``connected_components`` here. Without a
    # ``global`` declaration that name would be an unbound local, and jitted
    # code cannot assign to module globals.
    return True
    
@jit(nopython=True)
def process_values(arr):
    """Union every pair of adjacent elements within each row of ``arr``.

    For a row ``[a, b, c]`` this calls ``union(a, b)`` and ``union(b, c)``,
    so all elements of one row end up in the same set.

    Args:
        arr: a sequence of 1-D integer arrays. Rows may differ in length.
            Nopython mode cannot type NumPy ``dtype=object`` arrays, so ragged
            input should be a ``numba.typed.List`` of arrays instead.
    """
    for row in arr:
        typed_row = row.astype(np.int64)
        # Pair each element with its successor *within this row*.
        for first, second in zip(typed_row, typed_row[1:]):
            union(first, second)
            
# Ragged rows are packed into a NumPy ``dtype='object'`` array here. Numba's
# nopython mode cannot type object arrays ("non-precise type
# array(pyobject, 1d, C)"), so this call fails while typing the ``arr``
# argument of ``process_values``. That is the TypingError shown below.
# A ``numba.typed.List`` holding one int64 array per row is a supported
# alternative.
process_values(
       np.array(
           [np.array([8018361, 4645960]),
            np.array([1137555, 7763897]),
            np.array([7532943, 2248813]),
            np.array([5352737,   71466, 3590473, 5352738, 2712260])], dtype='object'))            

我无法理解这个错误:

TypingError                               Traceback (most recent call last)
<ipython-input-45-62735e65f581> in <module>
     44             np.array([1137555, 7763897]),
     45             np.array([7532943, 2248813]),
---> 46             np.array([5352737,   71466, 3590473, 5352738, 2712260])], dtype='object'))            

/opt/conda/lib/python3.7/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
    399                 e.patch_message(msg)
    400 
--> 401             error_rewrite(e, 'typing')
    402         except errors.UnsupportedError as e:
    403             # Something unsupported is present in the user code, add help info

/opt/conda/lib/python3.7/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
    342                 raise e
    343             else:
--> 344                 reraise(type(e), e, None)
    345 
    346         argtypes = []

/opt/conda/lib/python3.7/site-packages/numba/core/utils.py in reraise(tp, value, tb)
     78         value = tp()
     79     if value.__traceback__ is not tb:
---> 80         raise value.with_traceback(tb)
     81     raise value
     82 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type array(pyobject, 1d, C)
[1] During: typing of argument at <ipython-input-45-62735e65f581> (36)

File "<ipython-input-45-62735e65f581>", line 36:
def process_values(arr):
    for row in arr:
    ^

这是否与 process_values 接收的是不规则形状(ragged)的数组有关?有什么建议吗?谢谢!

问题在于 Numba 不支持 dtype 为 'object' 的数组。各行长度不同,无法组成规则的二维数组,因此 NumPy 只能把它们存成 object 数组。您似乎把数组嵌套在了数组中;这种情况下,您需要改用列表的列表(list of lists)。请查看 Numba 中的 typed.List 类:https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#typed-list

或者,您也可以使用 Awkward Array 库:https://github.com/scikit-hep/awkward-1.0