2016-12-10 25 views
0

我想繪製刪除樣本(行)的效果。有人稱之爲「學習曲線」。如何發送數據幀到scikit進行交叉驗證?

所以我想使用熊貓來刪除一些行。 How to remove, randomly, rows from a dataframe but from each label?

但是,當我想要做的交叉驗證,我得到以下錯誤(即使使用df.values把數據框到一個數組後):

enter image description here

所以,我是什麼做錯了?

這裏是我的代碼:

import pandas as pd 
import numpy as np 
from sklearn.model_selection import StratifiedShuffleSplit 
from sklearn import neighbors 
from sklearn import cross_validation 

df = pd.DataFrame(np.random.rand(12, 5)) 
label = np.array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) 
df['label'] = label 

df1 = pd.concat(g.sample(2) for idx, g in df.groupby('label')) 

X = df1[[0, 1, 2, 3, 4]].values 
y = df1.label.values 
print(X) 
print(y) 

clf = neighbors.KNeighborsClassifier() 
sss = StratifiedShuffleSplit(1, test_size=0.1) 
scoresSSS = cross_validation.cross_val_score(clf, X, y, cv=sss) 
print(scoresSSS) 

回答

1

馬上蝙蝠,與sss = StratifiedShuffleSplit(n_splits=1, test_size=0.35)你生成一個對象,而不是一個可迭代:

>>> type(sss) 
    <class 'sklearn.model_selection._split.StratifiedShuffleSplit'> 

而不是給StratifiedShuffleSplit類你的整個對象(這顯然是不可迭代的,因此錯誤),你需要給它的類的.split()方法(docs)的火車/測試輸出。另外,StratifiedShuffleSplit類中的test_size參數太小。如果您使用0.1,則會拋出ValueError,因爲您有3個獨特的類,因此測試大小的0.1不會。最後,您在KNeighbors clf對象中使用默認的n_neighbors參數值。使用如此小的數據集時,此默認值太大。由於n_neighbors <= n_samples,使用你所擁有的將會拋出另一個ValueError。所以在我下面的例子我已經調升測試規模在StratifiedShuffleSplit對象,下降n_neighbors下降到2,並通過了iterables從sss.split(X, y)cross_validation.cross_val_scorecv PARAM。

因此,這裏是你希望你的代碼是什麼樣子:

import pandas as pd 
import numpy as np 
from sklearn.model_selection import StratifiedShuffleSplit 
from sklearn import neighbors 
from sklearn import cross_validation 

df = pd.DataFrame(np.random.rand(12, 5)) 
label=np.array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) 
df['label'] = label 

df1 = pd.concat(g.sample(2) for idx, g in df.groupby('label')) 


X = df1[[0,1,2,3,4]].values 
y = df1.label.values 

clf = neighbors.KNeighborsClassifier(n_neighbors=2) 
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.35) 

scoresSSS = cross_validation.cross_val_score(clf, X, y, cv=sss.split(X, y)) 
print(scoresSSS) 

我只想說,我不知道比分你正在尋找得到的,並絕不是我在聲稱這將優化你的分數。但是,這將幫助您擺脫這些錯誤,以便您可以重新開始工作。

相關問題