Is there a cleaner way to achieve curve fitting with multiple models?

In my project I have predefined several families of functions for fitting curves. Here are the simplest ones:

def polyfit3(x, b0, b1, b2, b3):
    return b0+b1*x+b2*x**2+b3*x**3

def polyfit2(x, b0, b1, b2):
    return b0+b1*x+b2*x**2

def polyfit1(x, b0, b1):
    return b0+b1*x

Note: I know that np.polyfit would be the better choice in this particular case.
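For comparison, here is a minimal sketch of the np.polyfit route mentioned in the note, assuming plain NumPy arrays for x and y. The order is just an argument, so no per-order branching is needed. np.polyfit returns coefficients from the highest degree down, so the sketch reverses them to match the b0, b1, ... ordering used above. The helper name `fit_poly_numpy` is hypothetical.

```python
import numpy as np


def fit_poly_numpy(x, y, order=3):
    """Hypothetical helper: least-squares polynomial fit of the given order.

    Returns the coefficients ordered as b0, b1, ..., b_order (lowest degree first).
    """
    # np.polyfit returns the highest-degree coefficient first, so reverse the result
    return np.polyfit(x, y, order)[::-1]


x = np.linspace(0.0, 1.0, 20)
y = 1.0 + 2.0 * x + 3.0 * x**2  # exact quadratic, no noise
coeffs = fit_poly_numpy(x, y, order=2)
```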

A (simplified) version of the function that does the fitting looks like this:

from scipy.optimize import curve_fit
try:
    from lmfit import Model
    _has_lmfit = True
except ImportError:
    _has_lmfit = False

def f(x, y, order=3):
    if _has_lmfit:
        if order == 3:
            fitModel = Model(polyfit3)
            params = fitModel.make_params(b0=0, b1=1, b2=1, b3=1)
            result = fitModel.fit(y, x=x, params=params)
        elif order == 2:
            fitModel = Model(polyfit2)
            params = fitModel.make_params(b0=0, b1=1, b2=1)
            result = fitModel.fit(y, x=x, params=params)
        elif order == 1:
            fitModel = Model(polyfit1)
            params = fitModel.make_params(b0=0, b1=1)
            result = fitModel.fit(y, x=x, params=params)
        else:
            raise ValueError('Order is out of range, please select from [1, 3].')
    else:
        if order == 3:
            popt, pcov = curve_fit(polyfit3, x, y)
            _function = polyfit3
        elif order == 2:
            popt, pcov = curve_fit(polyfit2, x, y)
            _function = polyfit2
        elif order == 1:
            popt, pcov = curve_fit(polyfit1, x, y)
            _function = polyfit1
        else:
            raise ValueError('Order is out of range, please select from [1, 3].')
    # more code there.. mostly working with the optimized parameters, plotting, etc.

My problem is that this gets very ugly very quickly, and I keep repeating myself. Is there a better way?

Edit:

I tried this:

def poly_fit(x, *args):
    return sum(b*x**i for i, b in enumerate(args))

...

fitModel = Model(poly_fit)
fitModel.make_params(**{f'b{i}': 1 for i in range(order+1)})

Unfortunately, lmfit raises an error:

ValueError: varargs '*args' is not supported
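lmfit works out the parameter names by inspecting the model function's signature, which is why a bare *args is rejected there. scipy's curve_fit is more lenient: it accepts a *args model function as long as p0 is supplied, because it then takes the number of parameters from len(p0). A minimal sketch of the non-lmfit branch under that assumption; `fit_order` is a hypothetical helper name, and the starting values (b0=0, the rest 1) follow the question.

```python
import numpy as np
from scipy.optimize import curve_fit


def poly_fit(x, *coeffs):
    # Evaluates b0 + b1*x + b2*x**2 + ... for however many coefficients are passed
    return sum(b * x**i for i, b in enumerate(coeffs))


def fit_order(x, y, order=3):
    # With a *args model, curve_fit cannot infer the parameter count from the
    # signature, so p0 is required; its length fixes the number of coefficients.
    p0 = [0.0] + [1.0] * order  # b0=0, all higher coefficients start at 1
    popt, pcov = curve_fit(poly_fit, x, y, p0=p0)
    return popt, pcov


x = np.linspace(0.0, 1.0, 20)
y = 1.0 + 2.0 * x + 3.0 * x**2  # exact quadratic, no noise
popt, pcov = fit_order(x, y, order=2)
```

This only sidesteps the problem for the curve_fit branch; the lmfit branch still needs a function with named parameters.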

I rewrote your code by creating a global configuration for your polyfit functions. It is a more Pythonic version of the if chain.

polyfits = {
    1: {
        'f': polyfit1,
        'params': ['b0', 'b1'],
        'vals':   [   0,    1],
    },
    2: {
        'f': polyfit2,
        'params': ['b0', 'b1', 'b2'],
        'vals':   [   0,    1,    1],
    },
    3: {
        'f': polyfit3,
        'params': ['b0', 'b1', 'b2', 'b3'],
        'vals':   [   0,    1,    1,    1],
    },
}

def f(x, y, order=3):
    if order not in polyfits:
        raise ValueError('Order is out of range, please select from {}.'.format(
            ','.join(map(str, polyfits))))
    _function = polyfits[order]['f']
    if _has_lmfit:
        fitModel = Model(_function)
        params = dict(zip(polyfits[order]['params'], polyfits[order]['vals']))
        params = fitModel.make_params(**params)
        result = fitModel.fit(y, x=x, params=params)
    else:
        popt, pcov = curve_fit(_function, x, y)
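The 'params' lists in this table still repeat what the function signatures already say. One way to drop them is to read the names from the functions with inspect.signature. A minimal sketch, assuming (as in the question) that the first argument of every model function is the independent variable x; `initial_params` and `POLYFITS` are hypothetical names. The resulting dict could be passed as fitModel.make_params(**params), or its values as p0 to curve_fit.

```python
import inspect


def polyfit1(x, b0, b1):
    return b0 + b1 * x


def polyfit2(x, b0, b1, b2):
    return b0 + b1 * x + b2 * x**2


def polyfit3(x, b0, b1, b2, b3):
    return b0 + b1 * x + b2 * x**2 + b3 * x**3


# Hypothetical lookup table: only the functions need to be listed
POLYFITS = {1: polyfit1, 2: polyfit2, 3: polyfit3}


def initial_params(func):
    """Hypothetical helper: map every parameter after x to a starting value.

    Follows the question's convention: the first coefficient starts at 0,
    all others at 1.
    """
    names = list(inspect.signature(func).parameters)[1:]  # skip x
    return {name: (0 if i == 0 else 1) for i, name in enumerate(names)}


params = initial_params(POLYFITS[3])
```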

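I assume the code you posted is a heavily simplified version of your real code (as written, it could be condensed even further than my version above).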

I think lmfit.models.PolynomialModel() does exactly what you want. The model takes the polynomial degree n as an argument and uses coefficients named c0, c1, ..., cn (it handles up to n=7):

from lmfit.models import PolynomialModel

def f(x, y, degree=3):
    fitModel = PolynomialModel(degree=degree)
    params = fitModel.make_params(c0=0, c1=1, c2=1, c3=0,
                                  c4=0, c5=0, c6=0, c7=0)
    # or, if you prefer to do it the hard way (use one or the other;
    # this line would replace the params created above):
    # params = fitModel.make_params(**{'c%d' % i: 0 for i in range(degree + 1)})

    return fitModel.fit(y, x=x, params=params)
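Written out, the function PolynomialModel fits is the ordinary polynomial in x with coefficients c0 through cn, with the degree capped at 7 as described above. That is why only c0 up to c{degree} exist as parameters:

```latex
f(x) = \sum_{i=0}^{n} c_i \, x^i = c_0 + c_1 x + c_2 x^2 + \dots + c_n x^n, \qquad n \le 7
```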

Note that it is fine to over-specify the coefficients here. That is, if degree=3, the call fitModel.make_params(c0=0, ..., c7=0) will not actually create parameters for c4, c5, c6, or c7.

PolynomialModel raises a TypeError if degree > 7, so I left out your explicit range check.

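I hope this gets you started, but it looks like you may also want to include other model functions. What I do in that case is make a dictionary that maps names to model classes: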

from lmfit.models import LinearModel, PolynomialModel, GaussianModel, ....

KnownModels = {'linear': LinearModel, 'polynomial': PolynomialModel, 
              'gaussian': GaussianModel, ...}

Then use it to build the model:

modelchoice = 'linear' # probably really came from user selection in a GUI

if modelchoice in KnownModels:
    model = KnownModels[modelchoice]()
else:
    raise ValueError("unknown model '%s'" % modelchoice)

params = model.make_params(....) # <- might know and store what the parameter names are
.....
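A variation on the KnownModels dictionary is to let each model register itself with a decorator, so adding a model means writing one function rather than also editing a central dict. The sketch below is plain Python, with the question's polynomial functions standing in for lmfit model classes. `register_model`, `MODEL_REGISTRY` and `get_model` are hypothetical names. The same pattern should also work for registering lmfit Model subclasses, or model functions that you later wrap with lmfit.Model, but that is an assumption, not something shown above.

```python
# Hypothetical registry: name -> model callable
MODEL_REGISTRY = {}


def register_model(name):
    """Hypothetical decorator: store the decorated callable under `name`."""
    def decorator(func):
        if name in MODEL_REGISTRY:
            raise ValueError(f"model '{name}' is already registered")
        MODEL_REGISTRY[name] = func
        return func  # return the function unchanged so it stays usable directly
    return decorator


@register_model('linear')
def polyfit1(x, b0, b1):
    return b0 + b1 * x


@register_model('quadratic')
def polyfit2(x, b0, b1, b2):
    return b0 + b1 * x + b2 * x**2


def get_model(name):
    # Same error handling as the KnownModels lookup, but listing the valid choices
    try:
        return MODEL_REGISTRY[name]
    except KeyError:
        known = ', '.join(sorted(MODEL_REGISTRY))
        raise ValueError(f"unknown model '{name}', choose from: {known}") from None


model_func = get_model('linear')  # hypothetical next step: lmfit.Model(model_func)
```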