Numba jit with scipy

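So I wanted to speed up a program I wrote, with the help of numba's jit. However, jit seems to be incompatible with many scipy functions because they use try ... except ... structures that jit cannot handle (am I right about this?).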

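A relatively simple solution I came up with is to copy the scipy source code I need and delete the try/except parts (I already know that it will not run into errors, so the try part will always work).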

But I don't like this solution, and I'm not sure whether it would even work.

My code is structured like this:

import scipy.integrate as integrate
from scipy.optimize import curve_fit
from numba import jit

def fitfunction():
    ...

@jit
def function(x):
    # do some stuff
    try:
        fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=(0,0,0), maxfev=500)
        for idx in some_list:
            integrated = integrate.quad(lambda x: fitfunction(fit_param), lower, upper)
    except:
        fit_param=(0,0,0)
        ...

This currently leads to the following error:

LoweringError: Failed at object (object mode backend)

I assume this is because jit cannot handle try/except (it also does not work if I put jit only on the curve_fit and integrate.quad parts and build my own try/except structure around them):

import scipy.integrate as integrate
from scipy.optimize import curve_fit
from numba import jit

def fitfunction():
    ...

@jit
def integral(lower, upper):
    return integrate.quad(lambda x: fitfunction(fit_param), lower, upper)

@jit
def fitting(x, y, pzero, max_fev):
    return curve_fit(fitfunction, x, y, p0=pzero, maxfev=max_fev)


def function(x):
    # do some stuff
    try:
        fit_param, fit_cov = fitting(x, y, (0,0,0), 500)
        for idx in some_list:
            integrated = integral(lower, upper)
    except:
        fit_param=(0,0,0)
        ...

Is there a way to use jit with scipy.integrate.quad and curve_fit without manually deleting all the try/except structures from the scipy code?

Would it even speed up the code?

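Numba simply is not a general-purpose library for speeding up code. There is a class of problems that can be solved much faster with numba (especially loops over arrays and number crunching), but everything else is either (1) not supported or (2) only slightly faster, or even a lot slower.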
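As an illustration of the "loops over arrays, number crunching" class of problems, here is a minimal sketch of the kind of function numba is designed to compile well: a nested loop over a numeric sequence. The function name `mean_pairwise_distance` is hypothetical, and the `try`/`except ImportError` fallback is an assumption added only so the sketch also runs as plain Python where numba is not installed (in real use you would pass a NumPy array rather than a list).

```python
try:
    from numba import njit  # compile with numba when it is available
except ImportError:
    def njit(func):  # hypothetical fallback: run the same code as plain Python
        return func


@njit
def mean_pairwise_distance(xs):
    """Average |xs[i] - xs[j]| over all pairs i < j: a nested loop over an array."""
    n = len(xs)
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            total += abs(xs[i] - xs[j])
    # Number of pairs with i < j is n * (n - 1) / 2
    return total / (n * (n - 1) / 2)


print(mean_pairwise_distance([0.0, 1.0, 3.0]))
```

Code like this, with explicit loops and only arithmetic on numbers, is where compiling pays off; a call into an already compiled SciPy routine has no comparable interpreted loop to remove.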

[...] would it even speed up the code?

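SciPy is already a high-performance library, so in most cases I would expect numba to perform worse (or, rarely, slightly better). You could do some profiling to find out whether the bottleneck really is in the code you jitted; if it is, you might get some improvement. But I suspect the bottleneck will be in SciPy's compiled code, and that compiled code is probably already heavily optimized (so it is really unlikely that you will find an implementation that could even "only" compete with that code).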
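One way to do the profiling suggested above, using only the standard library's `cProfile` and `pstats`, is sketched below. The functions `python_loop`, `library_call` and `workload` are hypothetical stand-ins: the first for your own interpreted number crunching, the second for a call into compiled library code such as `curve_fit` or `quad`.

```python
import cProfile
import io
import math
import pstats


def python_loop(n):
    # Hypothetical pure-Python number crunching: the kind of code numba can speed up
    total = 0.0
    for i in range(n):
        total += math.sin(i) * math.cos(i)
    return total


def library_call(n):
    # Hypothetical stand-in for a call into an already compiled library routine
    return sum(map(math.sqrt, range(n)))


def workload():
    python_loop(50_000)
    library_call(50_000)


def profile(func):
    """Run func under cProfile and return the collected pstats.Stats object."""
    profiler = cProfile.Profile()
    profiler.enable()
    func()
    profiler.disable()
    return pstats.Stats(profiler, stream=io.StringIO())


stats = profile(workload)
# Report the five most expensive entries by cumulative time
stats.sort_stats("cumulative").print_stats(5)
print(stats.stream.getvalue())
```

If most of the cumulative time sits inside the library call rather than in your own loops, jitting your own code is unlikely to help much.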

Is there a way to use jit with scipy.integrate.quad and curve_fit without manually deleting all try except structures from the scipy code?

As you correctly assumed, try and except are simply not supported by numba at this time.

2.6.1. Language

2.6.1.1. Constructs

Numba strives to support as much of the Python language as possible, but some language features are not available inside Numba-compiled functions. The following Python language features are not currently supported:

[...]

  • Exception handling (try .. except, try .. finally)

So the answer here is no.
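For numba versions that reject try/except inside compiled code, a common workaround (similar in spirit to the second code block in the question) is to keep exception handling in ordinary Python and have the compiled kernel report failure through a status flag, much like the `success` flag returned by `lmdif` and `dqags` in the NumbaMinpack/NumbaQuadpack code. The sketch below is illustrative only: `safe_log_sum` and `log_sum` are hypothetical names, and the `ImportError` fallback just lets it run without numba.

```python
import math

try:
    from numba import njit  # use numba if it is installed
except ImportError:  # hypothetical fallback so the sketch still runs as plain Python
    def njit(func):
        return func


@njit
def safe_log_sum(values):
    """Compiled kernel without try/except: report failure through a status flag."""
    total = 0.0
    for v in values:
        if v <= 0.0:
            return 0.0, False  # signal the error instead of raising
        total += math.log(v)
    return total, True


def log_sum(values):
    # Exception handling lives in ordinary Python, outside the compiled kernel
    total, ok = safe_log_sum(values)
    if not ok:
        raise ValueError("all values must be positive")
    return total


print(log_sum([1.0, math.e]))
```

The design choice is that the kernel never needs to raise or catch anything; it only returns data, and the uncompiled wrapper decides what an error means.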

try and except now work with numba. However, numba and scipy are still not compatible. Yes, Scipy calls compiled C and Fortran, but it does so in a way that numba can't handle.

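Fortunately, there are alternatives to scipy that work well with numba! Below I use NumbaQuadpack and NumbaMinpack to do some curve fitting and integration similar to your example code. Disclaimer: I put these packages together. Below, I also give an equivalent implementation in scipy.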

The Scipy implementation is ~18 times slower than the Scipy alternatives (NumbaQuadpack and NumbaMinpack).

Using the Scipy alternatives (0.23 ms)

from NumbaQuadpack import quadpack_sig, dqags
from NumbaMinpack import minpack_sig, lmdif
import numpy as np
import numba as nb
import timeit
np.random.seed(0)

x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)

@nb.jit
def fitfunction(x, A, B):
    return A*np.sin(B*x)

@nb.cfunc(minpack_sig)
def fitfunction_optimize(u_, fvec, args_):
    u = nb.carray(u_,(2,))
    args = nb.carray(args_,(200,))
    A, B = u
    x = args[:100]
    y = args[100:]
    for i in range(100):
        fvec[i] = fitfunction(x[i], A, B) - y[i] 
optimize_ptr = fitfunction_optimize.address

@nb.cfunc(quadpack_sig)
def fitfunction_integrate(x, data):
    A = data[0]
    B = data[1]
    return fitfunction(x, A, B)
integrate_ptr = fitfunction_integrate.address

@nb.njit
def fast_function():  
    try:
        neqs = 100
        u_init = np.array([2.0,.8],np.float64)
        args = np.append(x,y)
        fitparam, fvec, success, info = lmdif(optimize_ptr , u_init, neqs, args)
        if not success:
            raise Exception

        lower = 0.0
        uppers = np.linspace(np.pi,np.pi*2.0,200)
        solutions = np.empty(len(uppers))
        for i in range(len(uppers)):
            solutions[i], abserr, success = dqags(integrate_ptr, lower, uppers[i], data = fitparam)
            if not success:
                raise Exception
    except:
        print('doing something else')
        
fast_function()
iters = 1000
t_nb = timeit.Timer(fast_function).timeit(number=iters)/iters
print(t_nb)
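The `.address` attributes in the code above (`optimize_ptr`, `integrate_ptr`) are raw C function pointers: presumably the compiled MINPACK/QUADPACK routines receive only that integer address and call through it, which is why the callbacks must be `nb.cfunc`s with a fixed C signature. The standard-library sketch below uses `ctypes` purely as a stand-in for `nb.cfunc` to show the idea of "a function is just an address compiled code can call"; unlike a numba cfunc, a ctypes callback re-enters the Python interpreter on every call, so it illustrates the mechanism, not the speed.

```python
import ctypes

# C signature "double f(double)", analogous to the integrand an integrator calls
CALLBACK = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)


def square(x):
    return x * x


c_square = CALLBACK(square)  # wrap the Python function as a C-callable callback
address = ctypes.cast(c_square, ctypes.c_void_p).value  # raw integer address

# Compiled code only ever sees the integer; rebuild a callable from the address
from_address = CALLBACK(address)
print(hex(address), from_address(3.0))
```

Note that `c_square` must stay alive for as long as anything may call the address, in the same way the cfunc objects above must outlive the pointers taken from them.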

Using Scipy (4.4 ms)

import scipy.integrate as integrate
from scipy.optimize import curve_fit
import numpy as np
import numba as nb
import timeit

np.random.seed(0)

x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)

@nb.jit
def fitfunction(x, A, B):
    return A*np.sin(B*x)

def function():
    try:
        p0 = (2.0,.8)
        fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=p0, maxfev=500)

        lower = 0.0
        uppers = np.linspace(np.pi,np.pi*2.0,200)
        solutions = np.empty(len(uppers))
        for i in range(len(uppers)):
            solutions[i], abserr = integrate.quad(fitfunction, lower, uppers[i], args = tuple(fit_param))
    except:
        print('do something else')

function()
iters = 1000
t_sp = timeit.Timer(function).timeit(number=iters)/iters
print(t_sp)