我正在做一個分類任務。不過,以下幾種做法得到的結果略有不同:如何在 scikit-learn 中正確地(correctly)計算交叉驗證分數?
# First approach: build a 10-fold splitter over the full data set, iterate it
# to obtain a train/test partition, then cross-validate on the training part.
kf = KFold(n=len(y), n_folds=10, shuffle=True, random_state=False)
pipe = make_pipeline(SVC())
for train_index, test_index in kf:
    # Every iteration overwrites the previous split, so only the partition of
    # the final fold is still bound once the loop ends.
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
# NOTE(review): the loop-body indentation was lost in the post; a single result
# line is shown for this approach, so the print presumably runs once after the
# loop rather than once per fold -- confirm.
# No cv= is passed here, so cross_val_score uses its own default splitter.
print('Precision', np.mean(cross_val_score(pipe, X_train, y_train, scoring='precision')))

# Second approach: fit once on the training split and score the held-out split.
# NOTE(review): clf is not defined in this snippet -- presumably an estimator
# created earlier; confirm.
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print('Precision:', precision_score(y_test, y_pred, average='binary'))

# Third approach: cross-validate on the full data set with the explicit splitter.
pipe = make_pipeline(SVC())
print('Precision', np.mean(cross_val_score(pipe, X, y, cv=kf, scoring='precision')))

# Fourth approach: same splitter, but applied to the training split only.
# NOTE: kf was created with n=len(y), so the indices it yields address the full
# data set; X_train is smaller, which is what raises the IndexError shown in
# the traceback. Kept as posted because it is the subject of the question.
pipe = make_pipeline(SVC())
print('Precision', np.mean(cross_val_score(pipe, X_train, y_train, cv=kf, scoring='precision')))
輸出:
Precision: 0.780422106837
Precision: 0.782051282051
Precision: 0.801544091998
/usr/local/lib/python3.5/site-packages/sklearn/cross_validation.py in cross_val_score(estimator, X, y, scoring, cv, n_jobs, verbose, fit_params, pre_dispatch)
1431 train, test, verbose, None,
1432 fit_params)
-> 1433 for train, test in cv)
1434 return np.array(scores)[:, 0]
1435
/usr/local/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in __call__(self, iterable)
798 # was dispatched. In particular this covers the edge
799 # case of Parallel used with an exhausted iterator.
--> 800 while self.dispatch_one_batch(iterator):
801 self._iterating = True
802 else:
/usr/local/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in dispatch_one_batch(self, iterator)
656 return False
657 else:
--> 658 self._dispatch(tasks)
659 return True
660
/usr/local/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in _dispatch(self, batch)
564
565 if self._pool is None:
--> 566 job = ImmediateComputeBatch(batch)
567 self._jobs.append(job)
568 self.n_dispatched_batches += 1
/usr/local/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in __init__(self, batch)
178 # Don't delay the application, to avoid keeping the input
179 # arguments in memory
--> 180 self.results = batch()
181
182 def get(self):
/usr/local/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in __call__(self)
70
71 def __call__(self):
---> 72 return [func(*args, **kwargs) for func, args, kwargs in self.items]
73
74 def __len__(self):
/usr/local/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in <listcomp>(.0)
70
71 def __call__(self):
---> 72 return [func(*args, **kwargs) for func, args, kwargs in self.items]
73
74 def __len__(self):
/usr/local/lib/python3.5/site-packages/sklearn/cross_validation.py in _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, error_score)
1522 start_time = time.time()
1523
-> 1524 X_train, y_train = _safe_split(estimator, X, y, train)
1525 X_test, y_test = _safe_split(estimator, X, y, test, train)
1526
/usr/local/lib/python3.5/site-packages/sklearn/cross_validation.py in _safe_split(estimator, X, y, indices, train_indices)
1589 X_subset = X[np.ix_(indices, train_indices)]
1590 else:
-> 1591 X_subset = safe_indexing(X, indices)
1592
1593 if y is not None:
/usr/local/lib/python3.5/site-packages/sklearn/utils/__init__.py in safe_indexing(X, indices)
161 indices.dtype.kind == 'i'):
162 # This is often substantially faster than X[indices]
--> 163 return X.take(indices, axis=0)
164 else:
165 return X[indices]
IndexError: index 900 is out of bounds for size 900
所以,我的問題是:上述哪一種方法才是計算交叉驗證指標(cross-validated metrics)的正確方式?我相信我的分數受到了污染,因爲我不清楚應該在什麼時候執行交叉驗證。因此,對於如何正確地計算交叉驗證分數,有什麼建議嗎?
UPDATE
在訓練步驟評估?
# Hold out a test split first; the cross-validation below only sees the
# training part.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=False)

# An SVC wrapped in a pipeline; any other estimator can be used for clf.
clf = make_pipeline(SVC())

# Stratified 10-fold splitter built from the training labels.
kf = StratifiedKFold(y=y_train, n_folds=10, shuffle=True, random_state=False)

# One precision value per fold.
scores = cross_val_score(
    clf,
    X_train,
    y_train,
    cv=kf,
    scoring='precision',
)
print('Mean score : ', scores.mean())
print('Score variance : ', scores.var())
感謝您的幫助。關於 'cross_val_score',您知道我爲什麼會得到不同的結果嗎?哪一種才是計算它們的正確方法? –
它可能是由於隨機種子。你是否嘗試設置random_state參數並查看會發生什麼? –
是的,我甚至在開始之前就先設定了隨機種子:'np.random.seed(1337)' –