在 matplotlib 中显示一条线,以便我的结果与我的教科书相匹配

Display a line in matplotlib, so that my results match up with my textbook

我正在阅读 ML 教科书:使用 scikit-learn 掌握机器学习,虽然我的代码给出了正确答案,但与书中的内容不符。

首先它给了我这个代码:

import matplotlib.pyplot as plt
X = [[6], [8], [10], [14],   [18]]
y = [[7], [9], [13], [17.5], [18]]
plt.figure()
plt.title('Pizza price plotted against diameter')
plt.xlabel('Diameter in inches')
plt.ylabel('Price in dollars')
plt.plot(X, y, 'k.')
plt.axis([0, 25, 0, 25])
plt.grid(True)
plt.show()

这在 matplotlib 中给了我这个图表:

这与我的结果相符。

但是,在下一步中它给了我这个代码:

from sklearn.linear_model import LinearRegression
# Training data
X = [[6], [8], [10], [14],   [18]]
y = [[7], [9], [13], [17.5], [18]]
# Create and fit the model
model = LinearRegression()
model.fit(X, y)
print 'A 12" pizza should cost: $%.2f' % model.predict([12])[0]

这张图表:

那个chard和我的代码不匹配,它没有matplotlib chart-maker函数。我尝试阅读指南并制作自己的指南:

from sklearn.linear_model import LinearRegression
import numpy as np
import matplotlib.pyplot as plt


X = [[6], [8], [10], [14],   [18]]
y = [[7], [9], [13], [17.5], [18]]


model = LinearRegression()
model.fit(X, y)

z = np.array([12]).reshape(-1,1)

print ('A 12" pizza should cost: $%.2f' % model.predict(z)[0])
print ("\n" + "_" * 50 + "\n")

plt.figure()
plt.title('Pizza price plotted against diameter')
plt.xlabel('Diameter in inches')
plt.ylabel('Price in dollars')
plt.plot(X, y, z, 'k.')
plt.axis([0, 25, 0, 25])
plt.grid(True)
plt.show()

但这只给了我这个奇怪的蓝色东西:

我是 python 中的数学新手,所以如果有人能给我更多关于如何解决这个问题的信息,我将不胜感激。

你得到的这个"weird blue thing"是由线段连接在一起的数据;你的数据应该使用 plt.scatter 绘制,这会给你一个点云。

您对回归线的计算是正确的,需要解决的是如何在您的数据集上绘制该线:

拟合数据后,您需要提取绘制回归线所需的值;您需要的数据是 x 轴两端的两个点(此处为 x=0x=25)。如果我们对这两个值调用 model.predict,我们将获得相应的预测。这些 x 值加上它们相应的预测形成了我们将用来绘制直线的两个点。

首先我们提取预测值y0y25。然后我们使用 plt.plot 与点 (0, y0) 和 (25, y25) 绘制绿色回归线。

from sklearn.linear_model import LinearRegression
import numpy as np
import matplotlib.pyplot as plt


X = [[6], [8], [10], [14],   [18]]
y = [[7], [9], [13], [17.5], [18]]


model = LinearRegression()
model.fit(X, y)

z = np.array([12]).reshape(-1,1)

print ('A 12" pizza should cost: $%.2f' % model.predict(z)[0])
print ("\n" + "_" * 50 + "\n")

plt.figure()
plt.title('Pizza price plotted against diameter')
plt.xlabel('Diameter in inches')
plt.ylabel('Price in dollars')
plt.scatter(X, y, z, 'k')

y0, y25 = model.predict(0)[0][0], model.predict(25)[0][0]
plt.plot((0, 25), (y0, y25), 'g')

plt.axis([0, 25, 0, 25])
plt.grid(True)
plt.show()