条件曲线符合 scipy?

Conditional curve fit with scipy?

假设我想用一条直线拟合我在熄灯情况下记录的数据。现在我不小心让灯亮了,我的数据从数据点 101 开始有一个恒定的偏移量。

我该如何适应这个?我尝试为 x 添加条件,但出现错误

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

记得取消注释其余代码(以遇到错误)。

import numpy as np
from scipy import optimize
import matplotlib.pyplot as plt

d1 = np.random.normal(0,0.1, 100)
d2 = np.random.normal(3,0.1, 100)

x = np.arange(0,200)
y = np.concatenate((d1,d2))

plt.plot(x, y)

# def line(x, a, b, offset):
#     if x < 101:
#         y = a * x + b
#     else:
#         y = (a * x + b) + offset
#     return y
# 
# popt, pcov = optimize.curve_fit(line, xdata = x, ydata = y)
# 
# plt.plot(x, line(x, *popt), color = "firebrick")
plt.show()

预期输出:

我认为标准技巧是将布尔条件转换为整数因子:

def line(x, a, b, offset):
    return (a * x + b) + offset * (x>100)

您收到该错误的原因是 optimize 正在调用您的 line 函数,方法是向它传递一个 array 值,而不仅仅是一个值。要解决此问题,您的 line 函数必须能够处理一组值。幸运的是,numpy有一个功能可以帮助你。

def line(x, a, b, offset):
    return np.piecewise(x, 
                        [x < 101, x >= 101],
                        [lambda x: a * x + b, lambda x: a * x + b + offset])

我应该注意到它仍然没有收敛,但这是一个不同的问题。