2017-07-12 38 views
5

我構建了一個scikit-learn模型,我想在日常的python cron作業中重用(NB:沒有涉及其他平臺 - 沒有R,沒有Java & c)。scikit-learn模型持久性:pickle vs pmml vs ...?

pickled它(實際上,我醃我自己的對象,其中一個字段是GradientBoostingClassifier),我在cron作業中取消了它。迄今爲止這麼好(並已在Save classifier to disk in scikit-learnModel persistence in Scikit-Learn?中討論過)。

不過,我升級sklearn現在我得到這些警告:

.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator DecisionTreeRegressor from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk. 
UserWarning) 
.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator PriorProbabilityEstimator from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk. 
UserWarning) 
.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator GradientBoostingClassifier from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk. 
UserWarning) 

現在我該怎麼辦?

  • 我可以downgrage到0.18.1,並堅持下去,直到我準備重建模型。由於各種原因,我覺得這是不可接受的。

  • 我可以取消醃製該文件並重新醃一遍。這與0.18.2一起工作,但與0.19打破。 NFG。 joblib看起來不會更好。

  • 我希望我可以將數據保存爲獨立於版本的ASCII格式(例如JSON或XML)。顯然,這是最佳解決方案,但似乎有辦法做到這一點(另請參閱Sklearn - model persistence without pkl file)。

  • 我可以將模型保存到PMML,但它的支持是不冷不熱,充其量: 我可以用sklearn2pmml保存模型(儘管不容易),和augustus/lightpmmlpredictor申請(雖然加載)模型。然而,這些都不是直接可用於pip,這使得部署成爲一場噩夢。此外,augustus & lightpmmlpredictor項目似乎已經死亡。 Importing PMML models into Python (Scikit-learn) - 不。

  • 上述變體:使用sklearn2pmml保存PMML,並使用openscoring進行評分。需要與外部進程進行交互。育。

對此有何建議?

回答

2

不同版本的scikit-learn的模型持久性通常是不可能的。原因很明顯:你用一個定義醃製Class1,並且想用其他定義將其取消放入Class2

您可以:

  • 仍然試圖堅持sklearn的一個版本。
  • 忽略警告,並希望對Class1有效的工作也適用於Class2
  • 寫你自己的班,可以序列化你的GradientBoostingClassifier並從這個序列化的形式恢復它,並希望它會比鹹菜更好。

我做了一個例子,說明如何將單個DecisionTreeRegressor轉換爲純粹的列表和字典格式,完全兼容JSON並將其恢復。

import numpy as np 
from sklearn.tree import DecisionTreeRegressor 
from sklearn.datasets import make_classification 

### Code to serialize and deserialize trees 

LEAF_ATTRIBUTES = ['children_left', 'children_right', 'threshold', 'value', 'feature', 'impurity', 'weighted_n_node_samples'] 
TREE_ATTRIBUTES = ['n_classes_', 'n_features_', 'n_outputs_'] 

def serialize_tree(tree): 
    """ Convert a sklearn.tree.DecisionTreeRegressor into a json-compatible format """ 
    encoded = { 
     'nodes': {}, 
     'tree': {}, 
     'n_leaves': len(tree.tree_.threshold), 
     'params': tree.get_params() 
    } 
    for attr in LEAF_ATTRIBUTES: 
     encoded['nodes'][attr] = getattr(tree.tree_, attr).tolist() 
    for attr in TREE_ATTRIBUTES: 
     encoded['tree'][attr] = getattr(tree, attr) 
    return encoded 

def deserialize_tree(encoded): 
    """ Restore a sklearn.tree.DecisionTreeRegressor from a json-compatible format """ 
    x = np.arange(encoded['n_leaves']) 
    tree = DecisionTreeRegressor().fit(x.reshape((-1,1)), x) 
    tree.set_params(**encoded['params']) 
    for attr in LEAF_ATTRIBUTES: 
     for i in range(encoded['n_leaves']): 
      getattr(tree.tree_, attr)[i] = encoded['nodes'][attr][i] 
    for attr in TREE_ATTRIBUTES: 
     setattr(tree, attr, encoded['tree'][attr]) 
    return tree 

## test the code 

X, y = make_classification(n_classes=3, n_informative=10) 
tree = DecisionTreeRegressor().fit(X, y) 
encoded = serialize_tree(tree) 
decoded = deserialize_tree(encoded) 
assert (decoded.predict(X)==tree.predict(X)).all() 

到這一點,你可以去序列化和反序列化整個GradientBoostingClassifier

from sklearn.ensemble import GradientBoostingClassifier 
from sklearn.ensemble.gradient_boosting import PriorProbabilityEstimator 

def serialize_gbc(clf): 
    encoded = { 
     'classes_': clf.classes_.tolist(), 
     'max_features_': clf.max_features_, 
     'n_classes_': clf.n_classes_, 
     'n_features_': clf.n_features_, 
     'train_score_': clf.train_score_.tolist(), 
     'params': clf.get_params(), 
     'estimators_shape': list(clf.estimators_.shape), 
     'estimators': [], 
     'priors':clf.init_.priors.tolist() 
    } 
    for tree in clf.estimators_.reshape((-1,)): 
     encoded['estimators'].append(serialize_tree(tree)) 
    return encoded 

def deserialize_gbc(encoded): 
    x = np.array(encoded['classes_']) 
    clf = GradientBoostingClassifier(**encoded['params']).fit(x.reshape(-1, 1), x) 
    trees = [deserialize_tree(tree) for tree in encoded['estimators']] 
    clf.estimators_ = np.array(trees).reshape(encoded['estimators_shape']) 
    clf.init_ = PriorProbabilityEstimator() 
    clf.init_.priors = np.array(encoded['priors']) 
    clf.classes_ = np.array(encoded['classes_']) 
    clf.train_score_ = np.array(encoded['train_score_']) 
    clf.max_features_ = encoded['max_features_'] 
    clf.n_classes_ = encoded['n_classes_'] 
    clf.n_features_ = encoded['n_features_'] 
    return clf 

# test on the same problem 
clf = GradientBoostingClassifier() 
clf.fit(X, y); 
encoded = serialize_gbc(clf) 
decoded = deserialize_gbc(encoded) 
assert (decoded.predict(X) == clf.predict(X)).all() 

這適用於scikit學習v0.19,但不要問我什麼會來的下一個版本將打破這個代碼。我既不是先知也不是sklearn的開發者。

如果你想完全獨立於新版本的sklearn,最安全的事情就是編寫一個遍歷序列化樹並進行預測的函數,而不是重新創建sklearn樹。

+0

這比鹹菜更可靠嗎? pickle的問題是,如果'sklearn'改變Class定義(例如,刪除或重命名一個槽),我將不得不重寫'serialize_ *'和'deserialize_ *'函數,更重要的是,編寫解串器來轉換將_old_版本序列化爲_new_版本。我同意這可能比鹹菜噩夢更好,但幾乎沒有。 – sds

+0

這不能保證你將與20或200版本的sklearn兼容。但它至少可以讓你更好地控制局勢。例如。如果sklearn完全重寫它的'''ClassificationLossFunction''',你將不會受到影響。 –