如何将 numba 与 functools.reduce() 一起使用
How to use numba together with functools.reduce()
我有以下代码,尝试使用 numba、functools.reduce() 和 mul 进行并行循环:
import numpy as np
from itertools import product
from functools import reduce
from operator import mul
from numba import jit, prange

lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
n = 3
flat = np.ravel(arr).tolist()
# Every length-n tuple drawn (with repetition) from the 9 values in `flat`,
# stacked into a 2-D integer NumPy array of shape (9**n, n).
gen = np.array([list(a) for a in product(flat, repeat=n)])


@jit(nopython=True, parallel=True)
def mtp(gen):
    # Intended result: results[i] = product of the entries in row i of `gen`.
    results = np.empty(gen.shape[0])
    for i in prange(gen.shape[0]):
        # This is the call Numba rejects with the TypingError shown below:
        # in nopython mode it fails to type reduce() when given the built-in
        # operator.mul (with initializer=None as a keyword).
        results[i] = reduce(mul, gen[i], initializer=None)
    return results


mtp(gen)
但运行时报了如下错误:
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
<ipython-input-503-cd6ef880fd4a> in <module>
10 results[i] = reduce(mul, gen[i], initializer=None)
11 return results
---> 12 mtp(gen)
~\Anaconda3\lib\site-packages\numba\dispatcher.py in _compile_for_args(self, *args, **kws)
399 e.patch_message(msg)
400
--> 401 error_rewrite(e, 'typing')
402 except errors.UnsupportedError as e:
403 # Something unsupported is present in the user code, add help info
~\Anaconda3\lib\site-packages\numba\dispatcher.py in error_rewrite(e, issue_type)
342 raise e
343 else:
--> 344 reraise(type(e), e, None)
345
346 argtypes = []
~\Anaconda3\lib\site-packages\numba\six.py in reraise(tp, value, tb)
666 value = tp()
667 if value.__traceback__ is not tb:
--> 668 raise value.with_traceback(tb)
669 raise value
670
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function reduce>) with argument(s) of type(s): (Function(<built-in function mul>), array(int32, 1d, C), initializer=none)
* parameterized
In definition 0:
AssertionError:
raised from C:\Users\HP\Anaconda3\lib\site-packages\numba\parfor.py:4138
In definition 1:
AssertionError:
raised from C:\Users\HP\Anaconda3\lib\site-packages\numba\parfor.py:4138
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function reduce>)
[2] During: typing of call at <ipython-input-503-cd6ef880fd4a> (10)
File "<ipython-input-503-cd6ef880fd4a>", line 10:
def mtp(gen):
<source elided>
for i in prange(gen.shape[0]):
results[i] = reduce(mul, gen[i], initializer=None)
^
我不确定自己哪里做错了,谁能给我指个方向?非常感谢。
您可以在经 numba jit 编译的函数中直接使用 np.prod:
import numpy as np
from itertools import product
from numba import jit, prange

n = 3
lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
flat = np.ravel(arr).tolist()
# Convert to a 2-D NumPy array of shape (9**n, n). A plain list of lists
# cannot be passed into a nopython-mode function (Numba cannot reflect nested
# lists), while a 2-D array is natively supported.
gen = np.array([list(a) for a in product(flat, repeat=n)])


@jit(nopython=True, parallel=True)
def mtp(gen):
    """Return a float array holding the product of each row of ``gen``.

    ``np.prod`` is supported in nopython mode (unlike ``functools.reduce``
    with the built-in ``operator.mul``). Each iteration writes only its own
    ``results[i]``, so ``prange`` can distribute rows across threads.
    """
    results = np.empty(gen.shape[0])
    for i in prange(gen.shape[0]):
        results[i] = np.prod(gen[i])
    return results
或者,也可以像下面这样使用 reduce(感谢 @stuartarchibald 指出这一点),不过下面的写法无法并行化(至少截至 numba 0.48 是这样):
import numpy as np
from itertools import product
from functools import reduce
from operator import mul
from numba import njit, prange
lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
n = 3
flat = np.ravel(arr).tolist()
# 2-D integer array of shape (9**n, n): every length-n tuple from `flat`.
gen = np.array([list(a) for a in product(flat, repeat=n)])


# Jitted wrapper around operator.mul. Inside nopython code, reduce() is handed
# this compiled function instead of the built-in `mul`, which is what the
# failing version passed.
@njit
def mul_wrapper(x, y):
    return mul(x, y)


@njit
def mtp(gen):
    """Return a float array holding the product of each row of ``gen``."""
    results = np.empty(gen.shape[0])
    # This decorator has no parallel=True, so Numba treats prange like range
    # and the loop runs serially.
    for i in prange(gen.shape[0]):
        # NOTE(review): passing None positionally relies on Numba's reduce
        # treating None as "no initializer"; CPython's functools.reduce would
        # instead use None as the starting value and fail on None * x.
        results[i] = reduce(mul_wrapper, gen[i], None)
    return results


print(mtp(gen))
或者,由于 Numba 内部有一套机制,能够发现逃逸出外层函数的闭包并将其一并编译(再次感谢 @stuartarchibald),你也可以这样写:
# Uses np, mul, reduce, njit and prange as imported in the preceding snippet.
@njit
def mtp(gen):
    """Return a float array holding the product of each row of ``gen``."""
    results = np.empty(gen.shape[0])

    # Closure defined inside the jitted function and passed to reduce().
    # Numba detects that it escapes into the call and compiles it along with
    # `mtp`, so no separate @njit wrapper is needed.
    def op(x, y):
        return mul(x, y)

    # No parallel=True on the decorator, so prange behaves like range here.
    for i in prange(gen.shape[0]):
        results[i] = reduce(op, gen[i], None)
    return results
但是,截至 numba 0.48,这种写法同样无法并行化。
注意,核心开发团队成员推荐的是第一种、即使用 np.prod 的解决方案:它可以配合 parallel 标志使用,实现也更直接。
我有以下代码,尝试使用 numba、functools.reduce() 和 mul 进行并行循环:
import numpy as np
from itertools import product
from functools import reduce
from operator import mul
from numba import jit, prange

lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
n = 3
flat = np.ravel(arr).tolist()
# Every length-n tuple drawn (with repetition) from the 9 values in `flat`,
# stacked into a 2-D integer NumPy array of shape (9**n, n).
gen = np.array([list(a) for a in product(flat, repeat=n)])


@jit(nopython=True, parallel=True)
def mtp(gen):
    # Intended result: results[i] = product of the entries in row i of `gen`.
    results = np.empty(gen.shape[0])
    for i in prange(gen.shape[0]):
        # This is the call Numba rejects with the TypingError shown below:
        # in nopython mode it fails to type reduce() when given the built-in
        # operator.mul (with initializer=None as a keyword).
        results[i] = reduce(mul, gen[i], initializer=None)
    return results


mtp(gen)
但运行时报了如下错误:
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
<ipython-input-503-cd6ef880fd4a> in <module>
10 results[i] = reduce(mul, gen[i], initializer=None)
11 return results
---> 12 mtp(gen)
~\Anaconda3\lib\site-packages\numba\dispatcher.py in _compile_for_args(self, *args, **kws)
399 e.patch_message(msg)
400
--> 401 error_rewrite(e, 'typing')
402 except errors.UnsupportedError as e:
403 # Something unsupported is present in the user code, add help info
~\Anaconda3\lib\site-packages\numba\dispatcher.py in error_rewrite(e, issue_type)
342 raise e
343 else:
--> 344 reraise(type(e), e, None)
345
346 argtypes = []
~\Anaconda3\lib\site-packages\numba\six.py in reraise(tp, value, tb)
666 value = tp()
667 if value.__traceback__ is not tb:
--> 668 raise value.with_traceback(tb)
669 raise value
670
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function reduce>) with argument(s) of type(s): (Function(<built-in function mul>), array(int32, 1d, C), initializer=none)
* parameterized
In definition 0:
AssertionError:
raised from C:\Users\HP\Anaconda3\lib\site-packages\numba\parfor.py:4138
In definition 1:
AssertionError:
raised from C:\Users\HP\Anaconda3\lib\site-packages\numba\parfor.py:4138
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function reduce>)
[2] During: typing of call at <ipython-input-503-cd6ef880fd4a> (10)
File "<ipython-input-503-cd6ef880fd4a>", line 10:
def mtp(gen):
<source elided>
for i in prange(gen.shape[0]):
results[i] = reduce(mul, gen[i], initializer=None)
^
我不确定自己哪里做错了,谁能给我指个方向?非常感谢。
您可以在经 numba jit 编译的函数中直接使用 np.prod:
import numpy as np
from itertools import product
from numba import jit, prange

n = 3
lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
flat = np.ravel(arr).tolist()
# Convert to a 2-D NumPy array of shape (9**n, n). A plain list of lists
# cannot be passed into a nopython-mode function (Numba cannot reflect nested
# lists), while a 2-D array is natively supported.
gen = np.array([list(a) for a in product(flat, repeat=n)])


@jit(nopython=True, parallel=True)
def mtp(gen):
    """Return a float array holding the product of each row of ``gen``.

    ``np.prod`` is supported in nopython mode (unlike ``functools.reduce``
    with the built-in ``operator.mul``). Each iteration writes only its own
    ``results[i]``, so ``prange`` can distribute rows across threads.
    """
    results = np.empty(gen.shape[0])
    for i in prange(gen.shape[0]):
        results[i] = np.prod(gen[i])
    return results
或者,也可以像下面这样使用 reduce(感谢 @stuartarchibald 指出这一点),不过下面的写法无法并行化(至少截至 numba 0.48 是这样):
import numpy as np
from itertools import product
from functools import reduce
from operator import mul
from numba import njit, prange
lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
n = 3
flat = np.ravel(arr).tolist()
# 2-D integer array of shape (9**n, n): every length-n tuple from `flat`.
gen = np.array([list(a) for a in product(flat, repeat=n)])


# Jitted wrapper around operator.mul. Inside nopython code, reduce() is handed
# this compiled function instead of the built-in `mul`, which is what the
# failing version passed.
@njit
def mul_wrapper(x, y):
    return mul(x, y)


@njit
def mtp(gen):
    """Return a float array holding the product of each row of ``gen``."""
    results = np.empty(gen.shape[0])
    # This decorator has no parallel=True, so Numba treats prange like range
    # and the loop runs serially.
    for i in prange(gen.shape[0]):
        # NOTE(review): passing None positionally relies on Numba's reduce
        # treating None as "no initializer"; CPython's functools.reduce would
        # instead use None as the starting value and fail on None * x.
        results[i] = reduce(mul_wrapper, gen[i], None)
    return results


print(mtp(gen))
或者,由于 Numba 内部有一套机制,能够发现逃逸出外层函数的闭包并将其一并编译(再次感谢 @stuartarchibald),你也可以这样写:
# Uses np, mul, reduce, njit and prange as imported in the preceding snippet.
@njit
def mtp(gen):
    """Return a float array holding the product of each row of ``gen``."""
    results = np.empty(gen.shape[0])

    # Closure defined inside the jitted function and passed to reduce().
    # Numba detects that it escapes into the call and compiles it along with
    # `mtp`, so no separate @njit wrapper is needed.
    def op(x, y):
        return mul(x, y)

    # No parallel=True on the decorator, so prange behaves like range here.
    for i in prange(gen.shape[0]):
        results[i] = reduce(op, gen[i], None)
    return results
但是,截至 numba 0.48,这种写法同样无法并行化。
注意,核心开发团队成员推荐的是第一种、即使用 np.prod 的解决方案:它可以配合 parallel 标志使用,实现也更直接。