2012-05-15 349 views
110

如何訓練有素的樸素貝葉斯分類保存,並用它來預測數據分類保存到磁盤scikit學習

我從scikit學習網站下面的示例程序:

from sklearn import datasets 
iris = datasets.load_iris() 
from sklearn.naive_bayes import GaussianNB 
gnb = GaussianNB() 
y_pred = gnb.fit(iris.data, iris.target).predict(iris.data) 
print "Number of mislabeled points : %d" % (iris.target != y_pred).sum() 

回答

123

分類器僅僅可以醃製和傾倒像任何其他對象。要繼續例如:

import cPickle 
# save the classifier 
with open('my_dumped_classifier.pkl', 'wb') as fid: 
    cPickle.dump(gnb, fid)  

# load it again 
with open('my_dumped_classifier.pkl', 'rb') as fid: 
    gnb_loaded = cPickle.load(fid) 
+0

就像一個魅力!我試圖使用np.savez並一直加載它,並從未幫助過。非常感謝。 – Kartos

156

您還可以使用joblib.dumpjoblib.load這是在處理數字陣列比默認的Python Pickler會更有效。

JOBLIB包括在scikit學習:

>>> from sklearn.externals import joblib 
>>> from sklearn.datasets import load_digits 
>>> from sklearn.linear_model import SGDClassifier 

>>> digits = load_digits() 
>>> clf = SGDClassifier().fit(digits.data, digits.target) 
>>> clf.score(digits.data, digits.target) # evaluate training error 
0.9526989426822482 

>>> filename = '/tmp/digits_classifier.joblib.pkl' 
>>> _ = joblib.dump(clf, filename, compress=9) 

>>> clf2 = joblib.load(filename) 
>>> clf2 
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0, 
     fit_intercept=True, learning_rate='optimal', loss='hinge', n_iter=5, 
     n_jobs=1, penalty='l2', power_t=0.5, rho=0.85, seed=0, 
     shuffle=False, verbose=0, warm_start=False) 
>>> clf2.score(digits.data, digits.target) 
0.9526989426822482 
+0

但從我的理解流水作品,如果它的一部分工作流。如果我想將模型存儲在磁盤上並停止執行。然後我回來一個星期後,嘗試從磁盤加載模型,它會拋出一個錯誤: – venuktan

+0

如果這是您正在查找的內容,則無法停止並恢復「fit」方法的執行。也就是說,如果您使用相同版本的scikit-learn庫從Python中調用joblib.load,那麼joblib.dump'成功後不應該引發異常。 – ogrisel

+7

如果您使用IPython,請不要使用'--pylab'命令行標誌或'%pylab'魔術,因爲已知隱式命名空間重載會中斷酸洗過程。相反,請使用顯式導入和'%matplotlib inline'魔術。 – ogrisel

49

你要找的是所謂模型持久在sklearn的話,它是在introductionmodel persistence部分記錄。

所以,你必須初始化你的分類,並與

clf = some.classifier() 
clf.fit(X, y) 

之後訓練的很長一段時間,你有兩種選擇:

1)用泡椒

import pickle 
# now you can save it to a file 
with open('filename.pkl', 'wb') as f: 
    pickle.dump(clf, f) 

# and later you can load it 
with open('filename.pkl', 'rb') as f: 
    clf = pickle.load(f) 

2)使用Joblib

from sklearn.externals import joblib 
# now you can save it to a file 
joblib.dump(clf, 'filename.pkl') 
# and later you can load it 
clf = joblib.load('filename.pkl') 

一個更多的時間是有幫助的閱讀上述鏈接

5

在許多情況下,尤其是文本分類是不足夠的存儲分類,但你需要存儲的矢量化,以及以便將來可以矢量化您的輸入。

import pickle 
with open('model.pkl', 'wb') as fout: 
    pickle.dump((vectorizer, clf), fout) 

將來使用情況:

with open('model.pkl', 'rb') as fin: 
    vectorizer, clf = pickle.load(fin) 

X_new = vectorizer.transform(new_samples) 
X_new_preds = clf.predict(X_new) 

在轉儲向量化,可以通過刪除矢量化的stop_words_屬性:

vectorizer.stop_words_ = None 

作出傾銷更爲有效。 此外,如果您的分類器參數是稀疏的(如在大多數文本分類示例中),則可以將參數從密集轉換爲稀疏,這會在內存消耗,加載和轉儲方面產生巨大差異。 Sparsify由模型:

clf.sparsify() 

,它會自動爲SGDClassifier工作,但如果你知道你的模型是稀疏(地段clf.coef_零),那麼你可以手動轉換CLF。coef_通過:

clf.coef_ = scipy.sparse.csr_matrix(clf.coef_) 

然後你可以更有效地存儲它。

1

sklearn估計器實現方法,使您可以輕鬆保存估算器的相關訓練屬性。一些估計實現__getstate__方法本身,而是其他人,像GMM只是使用它只是保存對象內部字典中的base implementation

def __getstate__(self): 
    try: 
     state = super(BaseEstimator, self).__getstate__() 
    except AttributeError: 
     state = self.__dict__.copy() 

    if type(self).__module__.startswith('sklearn.'): 
     return dict(state.items(), _sklearn_version=__version__) 
    else: 
     return state 

推薦的方法來保存你的模型到光盤是使用pickle模塊:

from sklearn import datasets 
from sklearn.svm import SVC 
iris = datasets.load_iris() 
X = iris.data[:100, :2] 
y = iris.target[:100] 
model = SVC() 
model.fit(X,y) 
import pickle 
with open('mymodel','wb') as f: 
    pickle.dump(model,f) 

但是,您應該保存額外的數據,以便將來可以重新訓練模型,或遭受可怕的後果(例如鎖定到舊版sklearn)

documentation

訓練數據,例如:

爲了重建與 未來版本scikit學習一個類似的模型,額外的元數據應該沿着醃製 模型保存得分訓練數據獲得一個不變的快照

用於生成模型

的scikit學習的版本和它的依賴

交叉驗證的蟒源代碼的引用

對於依賴於用Cython編寫的tree.pyx模塊(如IsolationForest)的合成估計器尤其如此,因爲它創建了一個耦合到實現在sklearn版本之間不保證穩定。它在過去看到了倒退不相容的變化。

如果您的模型變得非常大並且加載變得麻煩,您還可以使用效率更高的joblib。從文檔:

在scikit的特定情況下,它可能是更有趣使用 JOBLIB的更換picklejoblib.dump & joblib.load),這是 上在內部作爲攜帶大numpy的數組的對象更高效 往往是安裝scikit學習估計的情況下,卻只能 鹹菜到磁盤,而不是一個字符串: