如何使用嵌套 for 循环解决 numba 降低错误?

How to solve numba lowering error with nested for loops?

我想找到函数对数组乘积的最小值。我从一个简单的嵌套 for 循环和比较实现开始。由于 numba 帮助我在代码的许多其他地方获得了高加速,我发现我只需要在我的简单网格搜索中添加装饰器。等效示例是以下代码:

import numpy as np
from numba import njit


@njit()
def test():

    xs = np.arange(0, 1, 0.1)
    ys = np.arange(0, 1, 0.1)

    res = [0]
    smallest_value = np.inf
    for x in xs:
        for y in ys:
            value = x*x + y*y
            if value < smallest_value:
                smallest_value = value
                res = [x, y]

    return res, smallest_value


if __name__ == "__main__":

    print(test())

当我删除装饰器时它工作正常,但是有了它,我有以下错误:

File "numba_error.py", line 20:
def test():
    <source elided>

    return res, smallest_value
    ^

During: lowering "res.2 = res" at numba_error.py (20)

我回答了以下问题:How to Solve Numba Lowering error?但我没有模块可供参考。

这个小修改对我有用:使用元组而不是列表。

import numpy as np
from numba import njit


@njit()
def test():

    xs = np.arange(0, 1, 0.1)
    ys = np.arange(0, 1, 0.1)

    res = (0,0)
    smallest_value = np.inf
    for x in xs:
        for y in ys:
            value = x*x + y*y
            if value < smallest_value:
                smallest_value = value
                res = (x,y)

    return res, smallest_value


if __name__ == "__main__":

    print(test())