2017-11-11

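Nested cross-validation with spark_sklearn GridSearchCV produces a SPARK-5063 error

Running nested cross-validation with spark_sklearn's GridSearchCV as the inner CV and sklearn's cross_validate/cross_val_score as the outer CV fails with the error "It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation".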

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate
from spark_sklearn import GridSearchCV

# sparkcontext, X, y and verbose are defined earlier in the script
inner_cv = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)  # inner loop: hyperparameter search
outer_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)  # outer loop: performance estimate
scoring_metric = ['roc_auc', 'average_precision', 'precision']
gs = GridSearchCV(sparkcontext,
                  estimator=RandomForestClassifier(class_weight='balanced_subsample', n_jobs=-1),
                  param_grid=[{"max_depth": [5], "max_features": [.5, .8],
                               "min_samples_split": [2], "min_samples_leaf": [1, 2, 5, 10],
                               "bootstrap": [True, False], "criterion": ["gini", "entropy"],
                               "n_estimators": [300]}],
                  scoring=scoring_metric, cv=inner_cv, verbose=verbose, n_jobs=-1,
                  refit='roc_auc', return_train_score=False)
scores = cross_validate(gs, X, y, cv=outer_cv, scoring=scoring_metric, n_jobs=-1,
                        return_train_score=False)

I tried changing n_jobs=-1 to n_jobs=1 to remove the joblib-based parallelism and ran it again, but it still raises the same exception.

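Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.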

Full traceback:

Traceback (most recent call last):
  File "model_evaluation.py", line 350, in <module>
    main()
  File "model_evaluation.py", line 269, in main
    scores = cross_validate(gs, X, y, cv=outer_cv, scoring=scoring_metric, n_jobs=-1, return_train_score=False)
  File "../python27/lib/python2.7/site-packages/sklearn/model_selection/_validation.py", line 195, in cross_validate
    for train, test in cv.split(X, y, groups))
  File "../python27/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 779, in __call__
    while self.dispatch_one_batch(iterator):
  File "../python27/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 620, in dispatch_one_batch
    tasks = BatchedCalls(itertools.islice(iterator, batch_size))
  File "../python27/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 127, in __init__
    self.items = list(iterator_slice)
  File "../python27/lib/python2.7/site-packages/sklearn/model_selection/_validation.py", line 195, in <genexpr>
    for train, test in cv.split(X, y, groups))
  File "../python27/lib/python2.7/site-packages/sklearn/base.py", line 61, in clone
    new_object_params[name] = clone(param, safe=False)
  File "../python27/lib/python2.7/site-packages/sklearn/base.py", line 52, in clone
    return copy.deepcopy(estimator)
  File "/usr/local/lib/python2.7/copy.py", line 182, in deepcopy
    rv = reductor(2)
  File "/usr/local/lib/spark/python/pyspark/context.py", line 279, in __getnewargs__
    "It appears that you are attempting to reference SparkContext from a broadcast "
Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

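Edit: It seems the problem is that sklearn's cross_validate() clones the estimator for each fit, in a way similar to pickling the estimator object. That is not allowed for the spark_sklearn GridSearchCV estimator, because the SparkContext() object cannot (and should not) be pickled. So how do we clone the estimator correctly?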
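To see why n_jobs=1 does not help, here is a minimal pure-Python sketch of the failure that needs no Spark. FakeSparkContext, FakeSearchCV and simple_clone are hypothetical stand-ins: the first raises from __getnewargs__ the way pyspark's SparkContext does in the traceback above, and simple_clone is a simplified version of what sklearn.base.clone does, namely rebuilding the estimator from copy.deepcopy of each constructor parameter. Per the traceback, the clone runs inside the generator that feeds fits to joblib, so it happens whatever the n_jobs setting.

```python
import copy


class FakeSparkContext(object):
    """Hypothetical stand-in for pyspark.SparkContext.

    Like the real class, it refuses to be pickled or deep-copied by
    raising from __getnewargs__ (which deepcopy reaches via __reduce_ex__).
    """

    def __getnewargs__(self):
        raise Exception(
            "It appears that you are attempting to reference SparkContext "
            "from a broadcast variable, action, or transformation."
        )


class FakeSearchCV(object):
    """Hypothetical estimator that, like spark_sklearn's GridSearchCV,
    takes the context as its first constructor parameter."""

    def __init__(self, sc, param_grid):
        self.sc = sc
        self.param_grid = param_grid

    def get_params(self, deep=False):
        return {"sc": self.sc, "param_grid": self.param_grid}


def simple_clone(estimator):
    """Simplified sketch of sklearn.base.clone: rebuild the estimator
    from deep copies of its constructor parameters."""
    params = estimator.get_params(deep=False)
    new_params = {name: copy.deepcopy(value) for name, value in params.items()}
    return type(estimator)(**new_params)


if __name__ == "__main__":
    search = FakeSearchCV(FakeSparkContext(), {"max_depth": [5]})
    try:
        simple_clone(search)
    except Exception as exc:
        # deepcopy of the context parameter fails, just like in the traceback
        print("clone failed:", exc)
```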

Answer


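I finally figured out a solution. The problem occurs when the scikit-learn clone() function tries to deep-copy the SparkContext object. The solution I used is somewhat hacky, and I would gladly switch to a better one if it exists, but it works: import the copy module and override its deepcopy() function so that it simply returns SparkContext objects unchanged instead of copying them.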

# Mock the deep-copy function so it does not copy SparkContext objects.
# Helps avoid pickling errors or broadcast-variable (SPARK-5063) errors.
import copy

from pyspark import SparkContext

_deepcopy = copy.deepcopy  # keep a reference to the original implementation


def mock_deepcopy(x, *args, **kwargs):
    # Return the SparkContext itself instead of trying to copy it
    if isinstance(x, SparkContext):
        return x
    return _deepcopy(x, *args, **kwargs)


copy.deepcopy = mock_deepcopy

With this in place, clone() no longer tries to copy the SparkContext object, and everything appears to work correctly.