2013-07-07 33 views
33

我使用基於scikit-learn的nolearn的DBN(深層信念網絡)。python scikit-learn:導出受過訓練的分類器

我已經建立了一個網絡,可以很好地對我的數據進行分類,現在我對導出部署模型感興趣,但我不知道如何(每次我想要預測某事時我都在訓練DBN)。在matlab中,我只會導出權重矩陣並將其導入另一臺機器中。

是否有人知道如何導出模型/要輸入的權重矩陣,而無需重新訓練整個模型?

+2

您是否嘗試過使用[pickle](http://docs.python.org/2/library/pickle.html)模塊簡單地序列化模型? – ffriend

+0

@朋友 - 不,但我會嘗試。謝謝! – jcdmb

回答

54

您可以使用:

>>> from sklearn.externals import joblib 
>>> joblib.dump(clf, 'my_model.pkl', compress=9) 

再後來,預測服務器上:

>>> from sklearn.externals import joblib 
>>> model_clone = joblib.load('my_model.pkl') 

這基本上是一個Python泡菜大宗numpy的陣列的優化處理。它與常規泡菜w.r.t有相同的限制。代碼更改:如果pickle對象的類結構發生更改,則可能不再能夠使用nolearn或scikit-learn的新版本取消對象。

如果您想要長期穩健地存儲模型參數,您可能需要編寫自己的IO層(例如,使用二進制格式序列化工具,如協議緩衝區或avro或低效但可移植的text/json/xml表示如PMML)。

+1

我用'joblib.dump(clf,'my_model.pkl',compress = 9)''得到'RuntimeError:超過最大遞歸深度''。 –

+1

查看https://github.com/dnouri/nolearn/issues/271 –

2

scikit-learn文檔中的3.4. Model persistence部分幾乎涵蓋了所有內容。

除了sklearn.externals.joblib ogrisel指出,這顯示瞭如何使用常規的泡菜包:

>>> from sklearn import svm 
>>> from sklearn import datasets 
>>> clf = svm.SVC() 
>>> iris = datasets.load_iris() 
>>> X, y = iris.data, iris.target 
>>> clf.fit(X, y) 
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0, 
    kernel='rbf', max_iter=-1, probability=False, random_state=None, 
    shrinking=True, tol=0.001, verbose=False) 

>>> import pickle 
>>> s = pickle.dumps(clf) 
>>> clf2 = pickle.loads(s) 
>>> clf2.predict(X[0]) 
array([0]) 
>>> y[0] 
0 

,並給出了一些警告,如保存在scikit學習另一種可能無法加載的一個版本車型版。

7

酸洗/取出有一個缺點,它只適用於匹配python版本(主要和可能還有次要版本)和sklearn,joblib庫版本。

機器學習模型還有其他一些描述性輸出格式,例如由Data Mining Group開發的,如預測模型標記語言(PMML)和便攜式分析格式(PFA)。其中,PMML是much better supported

因此,您可以選擇將模型從scikit-learn保存到PMML中(例如使用sklearn2pmml),然後使用jpmml(當然您有更多選擇)在java,spark或hive中部署和運行它。

+0

看起來不錯,但如果部署也是基於Python的呢?有沒有'pmml2sklearn'? – VillasV