如何腌制sklearn Pipeline对象?
How to pickle sklearn Pipeline object?
我正在尝试保存管道。我不能。这是我的 class 对象,我尝试过酸洗。
class SentimentModel():
def __init__(self,model_instance,x_train,x_test,y_train,y_test):
import string
from nltk import ngrams
self.ngrams = ngrams
self.string = string
self.model = model_instance
self.x_train = x_train
self.x_test = x_test
self.y_train = y_train
self.y_test = y_test
self._fit()
def _fit(self):
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
self.pipeline = Pipeline([
('bow', CountVectorizer(analyzer=self._text_process)),
('tfidf', TfidfTransformer()),
('classifier', self.model),
])
self.pipeline.fit(self.x_train,self.y_train)
self.preds = self.pipeline.predict(self.x_test)
def _text_process(self,text):
def remove_non_ascii(text):
return ''.join(i for i in text if ord(i)<128)
text = remove_non_ascii(text)
text = [char.lower() for char in text if char not in self.string.punctuation]
text = ''.join(text)
unigrams = [word for word in text.split()]
bigrams = [' '.join(g) for g in self.ngrams(unigrams,2)]
trigrams = [' '.join(g) for g in self.ngrams(unigrams,3)]
tokens = []
tokens.extend(unigrams+bigrams+trigrams)
return tokens
def predict(self,observation):
return self.pipeline.predict(observation)
我收到这些错误:
from sklearn.naive_bayes import MultinomialNB
nb = MultinomialNB()
nb_model = SentimentModel(nb,X_train,X_test,y_train,y_test)
import pickle
with open('nb_model1.pkl','wb') as f:
pickle.dump(nb_model,f)
>>>
TypeError: can't pickle module objects
同样:
with open('nb_model1.pkl','wb') as f:
pickle.dump(nb_model.pipeline,f)
TypeError: can't pickle module objects
但是我可以保存 nb_model.model
。但不是管道对象。怎么解释?如何让我的整个管道持续存在?
我见过 ,但问题是,它无法 pickle bow
属性。
joblib.dump(nb_model.pipeline.get_params()['tfidf'], 'nb_tfidf.pkl') # pass
joblib.dump(nb_model.pipeline.get_params()['bow'], 'nb_bow.pkl') # fail
joblib.dump(nb_model.pipeline.get_params()['classifier'], 'nb_classifier.pkl') #pass
>>>
TypeError: can't pickle module objects
我该怎么办?
再试一次,不要在 class 定义中导入模块。这不是一个好的做法,因为当你导入诸如 import string
之类的东西时,你将一整套第三方代码带到你的代码中,而这些代码甚至可能没有安装在另一台想要使用这个 pickle 的机器上;这不是一个好习惯。也许pickle
是为了保护你做这种事。
我正在尝试保存管道。我不能。这是我的 class 对象,我尝试过酸洗。
class SentimentModel():
def __init__(self,model_instance,x_train,x_test,y_train,y_test):
import string
from nltk import ngrams
self.ngrams = ngrams
self.string = string
self.model = model_instance
self.x_train = x_train
self.x_test = x_test
self.y_train = y_train
self.y_test = y_test
self._fit()
def _fit(self):
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
self.pipeline = Pipeline([
('bow', CountVectorizer(analyzer=self._text_process)),
('tfidf', TfidfTransformer()),
('classifier', self.model),
])
self.pipeline.fit(self.x_train,self.y_train)
self.preds = self.pipeline.predict(self.x_test)
def _text_process(self,text):
def remove_non_ascii(text):
return ''.join(i for i in text if ord(i)<128)
text = remove_non_ascii(text)
text = [char.lower() for char in text if char not in self.string.punctuation]
text = ''.join(text)
unigrams = [word for word in text.split()]
bigrams = [' '.join(g) for g in self.ngrams(unigrams,2)]
trigrams = [' '.join(g) for g in self.ngrams(unigrams,3)]
tokens = []
tokens.extend(unigrams+bigrams+trigrams)
return tokens
def predict(self,observation):
return self.pipeline.predict(observation)
我收到这些错误:
from sklearn.naive_bayes import MultinomialNB
nb = MultinomialNB()
nb_model = SentimentModel(nb,X_train,X_test,y_train,y_test)
import pickle
with open('nb_model1.pkl','wb') as f:
pickle.dump(nb_model,f)
>>>
TypeError: can't pickle module objects
同样:
with open('nb_model1.pkl','wb') as f:
pickle.dump(nb_model.pipeline,f)
TypeError: can't pickle module objects
但是我可以保存 nb_model.model
。但不是管道对象。怎么解释?如何让我的整个管道持续存在?
我见过 bow
属性。
joblib.dump(nb_model.pipeline.get_params()['tfidf'], 'nb_tfidf.pkl') # pass
joblib.dump(nb_model.pipeline.get_params()['bow'], 'nb_bow.pkl') # fail
joblib.dump(nb_model.pipeline.get_params()['classifier'], 'nb_classifier.pkl') #pass
>>>
TypeError: can't pickle module objects
我该怎么办?
再试一次,不要在 class 定义中导入模块。这不是一个好的做法,因为当你导入诸如 import string
之类的东西时,你将一整套第三方代码带到你的代码中,而这些代码甚至可能没有安装在另一台想要使用这个 pickle 的机器上;这不是一个好习惯。也许pickle
是为了保护你做这种事。