如何在 numba 的“@guvectorize”中调用“@guvectorize”?

How call a `@guvectorize` inside a `@guvectorize` in numba?

我正在尝试在 @guvectorize 中调用 @guvectorize,但出现错误提示:

Untyped global name 'regNL_nb': cannot determine Numba type of <class 'numpy.ufunc'>

File "di.py", line 12:
def H2Delay_nb(S1, S2, R2):
    H1 = regNL_nb(S1, S2)
    ^

这是一个 MRE:

import numpy as np
from numba import guvectorize, float64, int64, njit, cuda, jit

@guvectorize(["float64[:], float64[:], float64[:]"], '(n),(n)->(n)')
def regNL_nb(S1, S2, h2):
    for i in range(len(S1)):
        h2[i] = S1[i] + S2[i]

@guvectorize(["float64[:], float64[:],  float64[:]"], '(n),(n)->(n)',nopython=True)
def H2Delay_nb(S1, S2, R2):
    H1 = regNL_nb(S1, S2)
    H2 = regNL_nb(S1, S2,)
    for i in range(len(S1)):
        R2[i] =  H1[i] + H2[i]


S1 = np.array([1,2,3,4,5,6,7,8,9])
S2 = np.array([1,2,3,4,5,6,7,8,9])
H2 = H2Delay_nb(S1, S2)
print(H2)

我不知道如何告诉 numba 函数 regNL_nb 是一个 guvectorize 函数。

@guvectorize(["float64[:], float64[:],  float64[:]"], '(n),(n)->(n)',nopython=True)
def H2Delay_nb(S1, S2, R2):
    H1 = regNL_nb(S1, S2)
    H2 = regNL_nb(S1, S2,)
    for i in range(len(S1)):
        R2[i] =  H1[i] + H2[i]

通过使用参数 nopython = True 可以停用对象模式,因此 Numba 无法将所有值作为 Python 对象处理(参考:https://numba.pydata.org/numba-doc/latest/glossary.html#term-object-mode

一般来说,如果您使用 nopython = True,Panda、Numba 或其他函数调用是不可能的。只有有限数量的库可以与 Numba Jit 一起使用(nopython)。 完整列表可在此处找到:https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html.

因此,除了禁用 nopython,您尝试做的事情是不可能的,即:

@guvectorize(["float64[:], float64[:],  float64[:]"], '(n),(n)->(n)',nopython=False)
    def H2Delay_nb(S1, S2, R2):
        H1 = regNL_nb(S1, S2)
        H2 = regNL_nb(S1, S2,)
        for i in range(len(S1)):
            R2[i] =  H1[i] + H2[i]

采用这种方法,程序输出正确的值,即 [ 4. 8. 12. 16. 20. 24. 28. 32. 36.] for H2.

我还发现了另一个处理熟悉问题的 Whosebug 问题:. Credits where credit is due: Kevin K. 在提到的线程中建议您应使用 'simpler' 数据类型 -最常见于 CPython。除此之外,在这一点上我完全同意他的看法,我不知道在 nopython 模式激活时有任何可能的解决方案。


来源:

我的回答仅适用于如果您愿意用@njit 替换@guvectorize 的情况,它将是完全相同的代码,速度相同,只是要使用的语法更长一些。

在 nopython 模式下接受其他 guvectorized 函数中的 @guvectorize-ed 函数可能存在一些问题。

但是 Numba 接受其他 njited 中非常好的常规 @njit-ed 函数。所以你可以重写你的函数以使用@njit,你的函数签名将与外部世界的@guvectorize-ed 保持相同。 @njit 版本只需要在函数内部额外使用 np.empty_like(...) + return。

提醒您 - 所有@njit-ed 函数始终启用 nopython 模式,因此您的 njited 代码将与 guvectorize+nopython 一样快。

我还提供了 CUDA 解决方案作为第二个代码片段。

您也可以将@njited 设为仅内部辅助函数,但外部您可能仍然可以使用@guvectorize-ed。此外,如果您想要通用功能(接受任何输入),只需从 njited 定义中删除签名 'f8[:](f8[:], f8[:])',签名将在调用时自动解析。

最终代码如下所示:

Try it online!

import numpy as np
from numba import guvectorize, float64, int64, njit, cuda, jit

@njit('f8[:](f8[:], f8[:])', cache = True)
def regNL_nb(S1, S2):
    h2 = np.empty_like(S1)
    for i in range(len(S1)):
        h2[i] = S1[i] + S2[i]
    return h2
        
@njit('f8[:](f8[:], f8[:])', cache = True)
def H2Delay_nb(S1, S2):
    H1 = regNL_nb(S1, S2)
    H2 = regNL_nb(S1, S2)
    R2 = np.empty_like(H1)
    for i in range(len(S1)):
        R2[i] =  H1[i] + H2[i]
    return R2

S1 = np.array([1,2,3,4,5,6,7,8,9], dtype = np.float64)
S2 = np.array([1,2,3,4,5,6,7,8,9], dtype = np.float64)
H2 = H2Delay_nb(S1, S2)
print(H2)

输出:

[ 4.  8. 12. 16. 20. 24. 28. 32. 36.]

相同代码的 CUDA 变体,如果您想自动创建和 return 结果数组,它需要额外的函数包装器,因为 CUDA 代码函数不允许具有 return 值:

import numpy as np
from numba import guvectorize, float64, int64, njit, cuda, jit

@cuda.jit('void(f8[:], f8[:], f8[:])', cache = True)
def regNL_nb_cu(S1, S2, h2):
    for i in range(len(S1)):
        h2[i] = S1[i] + S2[i]
        
@njit('f8[:](f8[:], f8[:])', cache = True)
def regNL_nb(S1, S2):
    h2 = np.empty_like(S1)
    regNL_nb_cu(S1, S2, h2)
    return h2
        
@cuda.jit('void(f8[:], f8[:], f8[:])', cache = True)
def H2Delay_nb_cu(S1, S2, R2):
    H1 = regNL_nb(S1, S2)
    H2 = regNL_nb(S1, S2)
    for i in range(len(S1)):
        R2[i] =  H1[i] + H2[i]
        
@njit('f8[:](f8[:], f8[:])', cache = True)
def H2Delay_nb(S1, S2):
    R2 = np.empty_like(S1)
    H2Delay_nb_cu(S1, S2, R2)
    return R2

S1 = np.array([1,2,3,4,5,6,7,8,9], dtype = np.float64)
S2 = np.array([1,2,3,4,5,6,7,8,9], dtype = np.float64)
H2 = H2Delay_nb(S1, S2)
print(H2)