Numba / Numpy - 理解错误信息
Numba / Numpy - Understanding Error Message
我正在尝试用 Numba 加速我正在编写的并查集（union-find）算法。下面是一些示例代码。用一些样本数据做实验时，我看不懂 Numba 报出的类型错误。
from numba import jit
import numpy as np
# Union-find state for 8,806,806 elements: a parent pointer per element
# (each element starts as its own root), the size of the set rooted at each
# element, and the running number of disjoint sets.
indices = np.arange(8_806_806, dtype=np.int64)
sizes = np.ones_like(indices)
connected_components = int(indices.size)
@jit(nopython=True)
def root(p: int) -> int:
    """Return the root (representative) of the set containing element ``p``.

    Follows parent pointers in the global ``indices`` array until it reaches
    an element that is its own parent. Along the way each visited element is
    re-pointed at its grandparent (path halving), which shortens later lookups.

    NOTE(review): Numba treats global NumPy arrays as compile-time constants
    in nopython mode (and types them read-only), so the write to ``indices``
    below is expected to be rejected at compile time. Passing ``indices`` in
    as an argument would avoid that, but would change this function's
    signature for existing callers.
    """
    while p != indices[p]:
        # Path halving: point p at its grandparent, then step to it.
        indices[p] = indices[indices[p]]
        p = indices[p]
    return p
@jit(nopython=True)
def connected(p: int, q: int) -> bool:
    """Return True if elements ``p`` and ``q`` belong to the same set.

    Two elements are connected exactly when ``root`` resolves them to the
    same representative.
    """
    return root(p) == root(q)
@jit(nopython=True)
def union(p: int, q: int) -> bool:
    """Merge the sets containing ``p`` and ``q`` using union by size.

    The root of the smaller set is attached under the root of the larger set,
    and the surviving root's size grows by the absorbed set's size. On a tie,
    ``q``'s root is attached under ``p``'s root.

    Returns:
        True if two distinct sets were merged, False if ``p`` and ``q`` were
        already connected.

    The previous version decremented the global ``connected_components``
    here. Assigning to that name inside the function makes it a local that is
    read before assignment (UnboundLocalError in plain Python), and Numba
    treats globals as compile-time constants. Maintaining the count is now
    the caller's job: decrement it whenever this returns True.

    NOTE(review): like ``root``, this writes to the global ``indices`` and
    ``sizes`` arrays, which Numba is expected to treat as read-only constants
    in nopython mode; passing them as arguments would be needed for the
    writes to compile.
    """
    root1 = root(p)
    root2 = root(q)
    if root1 == root2:
        # Already in the same set: nothing to merge.
        return False
    # Attach the smaller tree under the larger one to keep trees shallow.
    if sizes[root1] < sizes[root2]:
        indices[root1] = root2
        sizes[root2] += sizes[root1]
    else:
        indices[root2] = root1
        sizes[root1] += sizes[root2]
    return True
@jit(nopython=True)
def process_values(arr):
    """Union every pair of consecutive elements within each row of ``arr``.

    ``arr`` is iterated row by row; each row is converted to int64, and each
    element is unioned with its successor, so all elements of a row end up
    in one set. Rows may have different lengths; a row with fewer than two
    elements contributes no pairs.

    NOTE(review): in nopython mode Numba cannot type a NumPy array of
    dtype=object, so ragged input presumably needs to be passed as a
    numba.typed.List of 1-D int64 arrays rather than an object array.
    """
    for row in arr:
        typed_arr = row.astype('int64')
        # Pair consecutive elements *within this row*; the previous code
        # zipped ``arr`` with ``arr[1:]``, which paired whole rows instead.
        for first, second in zip(typed_arr, typed_arr[1:]):
            union(first, second)
# Ragged sample input: the rows have different lengths, so they are packed
# into a NumPy array with dtype='object'. Numba's nopython mode cannot type
# that argument ("non-precise type array(pyobject, 1d, C)"), which is the
# TypingError shown below; a numba.typed.List of int64 arrays avoids it.
process_values(
np.array(
[np.array([8018361, 4645960]),
np.array([1137555, 7763897]),
np.array([7532943, 2248813]),
np.array([5352737, 71466, 3590473, 5352738, 2712260])], dtype='object'))
我无法理解这个错误:
TypingError Traceback (most recent call last)
<ipython-input-45-62735e65f581> in <module>
44 np.array([1137555, 7763897]),
45 np.array([7532943, 2248813]),
---> 46 np.array([5352737, 71466, 3590473, 5352738, 2712260])], dtype='object'))
/opt/conda/lib/python3.7/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
399 e.patch_message(msg)
400
--> 401 error_rewrite(e, 'typing')
402 except errors.UnsupportedError as e:
403 # Something unsupported is present in the user code, add help info
/opt/conda/lib/python3.7/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
342 raise e
343 else:
--> 344 reraise(type(e), e, None)
345
346 argtypes = []
/opt/conda/lib/python3.7/site-packages/numba/core/utils.py in reraise(tp, value, tb)
78 value = tp()
79 if value.__traceback__ is not tb:
---> 80 raise value.with_traceback(tb)
81 raise value
82
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type array(pyobject, 1d, C)
[1] During: typing of argument at <ipython-input-45-62735e65f581> (36)
File "<ipython-input-45-62735e65f581>", line 36:
def process_values(arr):
for row in arr:
^
这是否与 process_values 接收的是形状不规则（各行长度不同）的数组有关？有什么建议吗？谢谢！
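问题在于 Numba 的 nopython 模式不支持 dtype 为 'object' 的数组。你把数组嵌套在了数组里，应改用“列表的列表”。请参阅 Numba 的 typed.List 类： https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#typed-list 。例如：先 `from numba.typed import List`，创建 `rows = List()`，对每一行调用 `rows.append(np.array([...], dtype=np.int64))`，然后调用 `process_values(rows)`。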
或者，你也可以使用 Awkward Array 库： https://github.com/scikit-hep/awkward-1.0
我正在尝试用 Numba 加速我正在编写的并查集（union-find）算法。下面是一些示例代码。用一些样本数据做实验时，我看不懂 Numba 报出的类型错误。
from numba import jit
import numpy as np
# Union-find state for 8,806,806 elements: a parent pointer per element
# (each element starts as its own root), the size of the set rooted at each
# element, and the running number of disjoint sets.
indices = np.arange(8_806_806, dtype=np.int64)
sizes = np.ones_like(indices)
connected_components = int(indices.size)
@jit(nopython=True)
def root(p: int) -> int:
    """Return the root (representative) of the set containing element ``p``.

    Follows parent pointers in the global ``indices`` array until it reaches
    an element that is its own parent. Along the way each visited element is
    re-pointed at its grandparent (path halving), which shortens later lookups.

    NOTE(review): Numba treats global NumPy arrays as compile-time constants
    in nopython mode (and types them read-only), so the write to ``indices``
    below is expected to be rejected at compile time. Passing ``indices`` in
    as an argument would avoid that, but would change this function's
    signature for existing callers.
    """
    while p != indices[p]:
        # Path halving: point p at its grandparent, then step to it.
        indices[p] = indices[indices[p]]
        p = indices[p]
    return p
@jit(nopython=True)
def connected(p: int, q: int) -> bool:
    """Return True if elements ``p`` and ``q`` belong to the same set.

    Two elements are connected exactly when ``root`` resolves them to the
    same representative.
    """
    return root(p) == root(q)
@jit(nopython=True)
def union(p: int, q: int) -> bool:
    """Merge the sets containing ``p`` and ``q`` using union by size.

    The root of the smaller set is attached under the root of the larger set,
    and the surviving root's size grows by the absorbed set's size. On a tie,
    ``q``'s root is attached under ``p``'s root.

    Returns:
        True if two distinct sets were merged, False if ``p`` and ``q`` were
        already connected.

    The previous version decremented the global ``connected_components``
    here. Assigning to that name inside the function makes it a local that is
    read before assignment (UnboundLocalError in plain Python), and Numba
    treats globals as compile-time constants. Maintaining the count is now
    the caller's job: decrement it whenever this returns True.

    NOTE(review): like ``root``, this writes to the global ``indices`` and
    ``sizes`` arrays, which Numba is expected to treat as read-only constants
    in nopython mode; passing them as arguments would be needed for the
    writes to compile.
    """
    root1 = root(p)
    root2 = root(q)
    if root1 == root2:
        # Already in the same set: nothing to merge.
        return False
    # Attach the smaller tree under the larger one to keep trees shallow.
    if sizes[root1] < sizes[root2]:
        indices[root1] = root2
        sizes[root2] += sizes[root1]
    else:
        indices[root2] = root1
        sizes[root1] += sizes[root2]
    return True
@jit(nopython=True)
def process_values(arr):
    """Union every pair of consecutive elements within each row of ``arr``.

    ``arr`` is iterated row by row; each row is converted to int64, and each
    element is unioned with its successor, so all elements of a row end up
    in one set. Rows may have different lengths; a row with fewer than two
    elements contributes no pairs.

    NOTE(review): in nopython mode Numba cannot type a NumPy array of
    dtype=object, so ragged input presumably needs to be passed as a
    numba.typed.List of 1-D int64 arrays rather than an object array.
    """
    for row in arr:
        typed_arr = row.astype('int64')
        # Pair consecutive elements *within this row*; the previous code
        # zipped ``arr`` with ``arr[1:]``, which paired whole rows instead.
        for first, second in zip(typed_arr, typed_arr[1:]):
            union(first, second)
# Ragged sample input: the rows have different lengths, so they are packed
# into a NumPy array with dtype='object'. Numba's nopython mode cannot type
# that argument ("non-precise type array(pyobject, 1d, C)"), which is the
# TypingError shown below; a numba.typed.List of int64 arrays avoids it.
process_values(
np.array(
[np.array([8018361, 4645960]),
np.array([1137555, 7763897]),
np.array([7532943, 2248813]),
np.array([5352737, 71466, 3590473, 5352738, 2712260])], dtype='object'))
我无法理解这个错误:
TypingError Traceback (most recent call last)
<ipython-input-45-62735e65f581> in <module>
44 np.array([1137555, 7763897]),
45 np.array([7532943, 2248813]),
---> 46 np.array([5352737, 71466, 3590473, 5352738, 2712260])], dtype='object'))
/opt/conda/lib/python3.7/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
399 e.patch_message(msg)
400
--> 401 error_rewrite(e, 'typing')
402 except errors.UnsupportedError as e:
403 # Something unsupported is present in the user code, add help info
/opt/conda/lib/python3.7/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
342 raise e
343 else:
--> 344 reraise(type(e), e, None)
345
346 argtypes = []
/opt/conda/lib/python3.7/site-packages/numba/core/utils.py in reraise(tp, value, tb)
78 value = tp()
79 if value.__traceback__ is not tb:
---> 80 raise value.with_traceback(tb)
81 raise value
82
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type array(pyobject, 1d, C)
[1] During: typing of argument at <ipython-input-45-62735e65f581> (36)
File "<ipython-input-45-62735e65f581>", line 36:
def process_values(arr):
for row in arr:
^
这是否与 process_values 接收的是形状不规则（各行长度不同）的数组有关？有什么建议吗？谢谢！
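问题在于 Numba 的 nopython 模式不支持 dtype 为 'object' 的数组。你把数组嵌套在了数组里，应改用“列表的列表”。请参阅 Numba 的 typed.List 类： https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#typed-list 。例如：先 `from numba.typed import List`，创建 `rows = List()`，对每一行调用 `rows.append(np.array([...], dtype=np.int64))`，然后调用 `process_values(rows)`。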
或者，你也可以使用 Awkward Array 库： https://github.com/scikit-hep/awkward-1.0