如何修复分段函数图的 ValueError

How to fix ValueError for piecewise function plot

我目前正在编写以下 python 脚本,以在指定域 np.linspace(1.0, 3.0) 上绘制给定的分段函数,并 运行 出现一些绘图错误。当我 运行 脚本时,出现以下错误:ValueError: x and y must have same first dimension, but have shapes (50,) and (42,)。我尝试通过 domain = np.linspace(1.0, 3.0, 42) 调整 domain 的长度以匹配 codomain 的长度,但没有成功。下面是我的代码:

import numpy as np
from matplotlib import pyplot as plt


def f(x):

    image = []
    for p in x:
        if p in np.linspace(1.0, 1.5):
            y = 1.000004 * (p - 1.0) + 0.486068 * (p - 1.0) ** 2 - 0.106566 * (p - 1.0) ** 3
            image.append(y)
        elif p in np.linspace(1.5, 2.0):
            y = 0.608198 + 1.406152 * (p - 1.5) + 0.326219 * (p - 1.5) ** 2 - 0.052277 * (p - 1.5) ** 3
            image.append(y)
        elif p in np.linspace(2.0, 2.5):
            y = 1.386294 + 1.693164 * (p - 2.0) + 0.247803 * (p - 2.0) ** 2 - 0.032798 * (p - 2.0) ** 3
            image.append(y)
        elif p in np.linspace(2.5, 3.0):
            y = 2.290727 + 1.916372 * (p - 2.5) + 0.198606 * (p - 2.5) ** 2 - 0.021819 * (p - 2.5) ** 3
            image.append(y)

    return image


domain = np.linspace(1.0, 3.0)
codomain = f(domain)

plt.plot(domain, codomain)
plt.show()

免责声明:这是我第一次 post 访问此站点,因此如果有任何需要调整以获得更好的反馈,请告诉我。

您似乎误解了 np.linspace 在做什么。具体来说,np.linspace(1,1.5) 中的值不保证在 np.linspace(1,3)

我建议阅读 documentation 以了解这是为什么。

您似乎想将原始输入分成四分位数,您可以使用 np.arraysplit(x,4)

import numpy as np
from matplotlib import pyplot as plt

def f(x):

    image = []
    a,b,c,d = np.array_split(x,4)
    for p in x:
        if p in a:
            y = 1.000004 * (p - 1.0) + 0.486068 * (p - 1.0) ** 2 - 0.106566 * (p - 1.0) ** 3
            image.append(y)
        elif p in b:
            y = 0.608198 + 1.406152 * (p - 1.5) + 0.326219 * (p - 1.5) ** 2 - 0.052277 * (p - 1.5) ** 3
            image.append(y)
        elif p in c:
            y = 1.386294 + 1.693164 * (p - 2.0) + 0.247803 * (p - 2.0) ** 2 - 0.032798 * (p - 2.0) ** 3
            image.append(y)
        elif p in d:
            y = 2.290727 + 1.916372 * (p - 2.5) + 0.198606 * (p - 2.5) ** 2 - 0.021819 * (p - 2.5) ** 3
            image.append(y)
        else:print(p)
    return image


domain = np.linspace(1.0, 3.0)
codomain = f(domain)

plt.plot(domain, codomain)
plt.show()