线性回归负载模型未按预期进行预测

Linear regression load model doesn't predict as expected

我已经使用 sklearn 训练了一个线性回归模型,获得了 5 星评级,它已经足够好了。我使用 Doc2vec 创建了我的向量,并保存了那个模型。然后我将线性回归模型保存到另一个文件。我想要做的是加载 Doc2vec 模型和线性回归模型并尝试预测另一篇评论。

这个预测有一些非常奇怪的地方:无论输入什么,它总是预测 2.1-3.0 左右。

事实是,我有一个建议,它预测的平均值约为 5(即 2.5 +/-),但事实并非如此。我在训练模型时打印了测试数据的预测值和实际值,它们的范围通常为 1-5。所以我的想法是,代码的加载部分有问题。这是我的加载代码:

from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from bs4 import BeautifulSoup
from joblib import dump, load
import pickle
import re

model = Doc2Vec.load('../vectors/750000/doc2vec_model')

def cleanText(text):
    text = BeautifulSoup(text, "lxml").text
    text = re.sub(r'\|\|\|', r' ', text) 
    text = re.sub(r'http\S+', r'<URL>', text)
    text = re.sub(r'[^\w\s]','',text)
    text = text.lower()
    text = text.replace('x', '')
    return text

review = cleanText("Horrible movie! I don't recommend it to anyone!").split()
vector = model.infer_vector(review)

pkl_filename = "../vectors/750000/linear_regression_model.joblib"
with open(pkl_filename, 'rb') as file:  
    linreg = pickle.load(file)

review_vector = vector.reshape(1,-1)
predict_star = linreg.predict(review_vector)
print(predict_star)

(更新: 我忽略了 .cleanText() 之后问题代码中进行的 .split() 标记化,所以这不是真正的问题。但是保留答案以供参考,因为真正的问题是在评论中发现的。)

当用户向 infer_vector() 提供纯字符串时,用户通常会从 Doc2Vec 得到非常弱的结果。 Doc2Vec infer_vector() 需要一个单词列表, 不是 一个字符串。

如果提供一个字符串,该函数会将其视为单字符单词列表 – 根据 Python 将字符串建模为字符列表和字符类型混合和一个字符串。模型可能不知道这些单字符词中的大部分,而那些可能是 'i''a' 等的词意义不大。所以推断的文档向量将是弱的和无意义的。 (而且,这样一个向量,输入到您的线性回归中,总是给出一个中等的预测值也就不足为奇了。)

如果您将文本分成预期的单词列表,您的结果应该会有所改善。

但更一般地说,提供给 infer_vector() 的词应该被预处理和标记化 完全 然而训练文档是。

(判断您是否正确进行推理的一个公平合理性测试是为您的一些训练文档推断向量,然后向 Doc2Vec 模型询问最接近这些重新推断向量的文档标签.一般来说,同一个文档的training-time tag/ID应该是top result,或者至少是top few的一个,如果不是,可能是数据、模型参数或者inference有其他问题.)

您的示例代码显示了 joblib.dumpjoblib.load 的导入——尽管本节选中均未使用。而且,您文件的后缀暗示模型最初可能是用 joblib.dump() 保存的,而不是香草泡菜。

但是,此代码显示文件仅通过纯 pickle.load() 加载——这可能是错误的来源。

The joblib.load() docs 建议它的 load() 可能会做一些事情,比如从它自己 dump() 创建的多个独立文件中加载 numpy 数组。 (奇怪的是,dump() 文档对此不太清楚,但据推测 dump() 有一个 return 值,它可能是文件名的 list。 )

您可以检查文件的保存位置以查找似乎相关的额外文件,并尝试使用 joblib.load() 而不是普通 pickle,看看它是否加载了您的 more-functional/more-complete 版本linreg 对象。