如何使用嵌套 for 循环解决 numba 降低错误?
How to solve numba lowering error with nested for loops?
我想找到函数对数组乘积的最小值。我从一个简单的嵌套 for 循环和比较实现开始。由于 numba 帮助我在代码的许多其他地方获得了高加速,我发现我只需要在我的简单网格搜索中添加装饰器。等效示例是以下代码:
import numpy as np
from numba import njit
@njit()
def test():
xs = np.arange(0, 1, 0.1)
ys = np.arange(0, 1, 0.1)
res = [0]
smallest_value = np.inf
for x in xs:
for y in ys:
value = x*x + y*y
if value < smallest_value:
smallest_value = value
res = [x, y]
return res, smallest_value
if __name__ == "__main__":
print(test())
当我删除装饰器时它工作正常,但是有了它,我有以下错误:
File "numba_error.py", line 20:
def test():
<source elided>
return res, smallest_value
^
During: lowering "res.2 = res" at numba_error.py (20)
我回答了以下问题:How to Solve Numba Lowering error?但我没有模块可供参考。
这个小修改对我有用:使用元组而不是列表。
import numpy as np
from numba import njit
@njit()
def test():
xs = np.arange(0, 1, 0.1)
ys = np.arange(0, 1, 0.1)
res = (0,0)
smallest_value = np.inf
for x in xs:
for y in ys:
value = x*x + y*y
if value < smallest_value:
smallest_value = value
res = (x,y)
return res, smallest_value
if __name__ == "__main__":
print(test())
我想找到函数对数组乘积的最小值。我从一个简单的嵌套 for 循环和比较实现开始。由于 numba 帮助我在代码的许多其他地方获得了高加速,我发现我只需要在我的简单网格搜索中添加装饰器。等效示例是以下代码:
import numpy as np
from numba import njit
@njit()
def test():
xs = np.arange(0, 1, 0.1)
ys = np.arange(0, 1, 0.1)
res = [0]
smallest_value = np.inf
for x in xs:
for y in ys:
value = x*x + y*y
if value < smallest_value:
smallest_value = value
res = [x, y]
return res, smallest_value
if __name__ == "__main__":
print(test())
当我删除装饰器时它工作正常,但是有了它,我有以下错误:
File "numba_error.py", line 20:
def test():
<source elided>
return res, smallest_value
^
During: lowering "res.2 = res" at numba_error.py (20)
我回答了以下问题:How to Solve Numba Lowering error?但我没有模块可供参考。
这个小修改对我有用:使用元组而不是列表。
import numpy as np
from numba import njit
@njit()
def test():
xs = np.arange(0, 1, 0.1)
ys = np.arange(0, 1, 0.1)
res = (0,0)
smallest_value = np.inf
for x in xs:
for y in ys:
value = x*x + y*y
if value < smallest_value:
smallest_value = value
res = (x,y)
return res, smallest_value
if __name__ == "__main__":
print(test())