使用代数约束和边界最小化最小二乘法

Minimizing Least Squares with Algebraic Constraints and Bounds

我试图根据一些矢量求和来最小化最小二乘和。简而言之,我正在创建一个方程式,它采用理想向量,用确定的系数对它们进行加权,然后对加权向量求和。一旦将此总和与为某些观察发现的实际矢量测量值进行比较,就会得出最小二乘和。

举个例子:

# Observation A has the following measurements:
A = [0, 4.1, 5.6, 8.9, 4.3]

# How similar is A to ideal groups identified by the following:
group1 = [1, 3, 5, 10, 3]
group2 = [6, 3, 2, 1, 10]
group3 = [3, 3, 4, 2, 1]

# Let y be the predicted measurement for A with coefficients s1, s2, and s3:
y = s1 * group1 + s2 * group2 + s3 * group3

# y will be some vector of length 5, similar to A
# Now find the sum of least squares between y and A
sum((y_i - A_i)** 2 for y_i in y for A_i in A)

Necessary bounds and constraints

0 <= s1, s2, s3 <= 1

s1 + s2 + s3 = 1

y = s1 * group1 + s2 * group2 + s3 * group3

我想最小化 y 和 A 的最小二乘和以获得系数 s1、s2、s3,但我很难确定 scipy.optimize 中的正确选择可能是什么是。那里用于最小化最小二乘和的函数似乎无法处理代数变量约束。我正在处理的数据是使用这些矢量化测量值进行的数千次观察。任何想法或想法将不胜感激!

对于您的情况,您可以像这样使用 scipy.optimize 中的 minimize()

minimize(fun=obj_fun, args=argtpl x0=xinit, bounds=bnds, constraints=cons)

其中 obj_fun(x, *args) 是您的 objective 函数,argtpl 是您 objective 函数的(可选)参数元组,xinit 初始点, bnds 变量边界的元组列表和 cons 约束的字典列表。

import numpy as np
from scipy.optimize import minimize

# Observation A has the following measurements:
A = np.array([0, 4.1, 5.6, 8.9, 4.3])
# How similar is A to ideal groups identified by the following:
group1 = np.array([1, 3, 5, 10, 3])
group2 = np.array([6, 3, 2, 1, 10])
group3 = np.array([3, 3, 4, 2, 1])

# Define the objective function
# x is the array containing your wanted coefficients
def obj_fun(x, A, g1, g2, g3):
    y = x[0] * g1 + x[1] * g2 + x[2] * g3
    return np.sum((y-A)**2)

# Bounds for the coefficients
bnds = [(0, 1), (0, 1), (0, 1)]
# Constraint: x[0] + x[1] + x[2] - 1 = 0
cons = [{"type": "eq", "fun": lambda x: x[0] + x[1] + x[2] - 1}]

# Initial guess
xinit = np.array([1, 1, 1])
res = minimize(fun=obj_fun, args=(A, group1, group2, group3), x0=xinit, bounds=bnds, constraints=cons)
print(res.x)

您的示例的解决方案:

array([9.25609756e-01, 7.43902439e-02, 6.24242179e-12])