2016-01-26

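StratifiedKFold: IndexError: too many indices for array

Using sklearn's StratifiedKFold function, can someone help me understand this error? My guess is that it has to do with my input array of labels. When I print them (the first 16 in this case), the index runs from 0 to 15 as expected, but an extra 0 is printed above them that I did not expect. Maybe I'm just a Python noob, but it looks weird.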

Does anyone see what is going on here?

Docs: http://scikit-learn.org...StratifiedKFold.html

Code:

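import nltk 
import sklearn 
from sklearn import cross_validation  # StratifiedKFold's module in sklearn 0.17 (later replaced by sklearn.model_selection)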

print('The nltk version is {}.'.format(nltk.__version__)) 
print('The scikit-learn version is {}.'.format(sklearn.__version__)) 

print type(skew_gendata_targets.values), skew_gendata_targets.values.shape 
print skew_gendata_targets.head(16) 

skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121) 

Result:

The nltk version is 3.1. 
The scikit-learn version is 0.17. 
<type 'numpy.ndarray'> (500L, 1L) 
    0 
0 0 
1 0 
2 0 
3 0 
4 0 
5 0 
6 0 
7 0 
8 0 
9 0 
10 0 
11 0 
12 0 
13 0 
14 1 
15 0 
--------------------------------------------------------------------------- 
IndexError        Traceback (most recent call last) 
<ipython-input-373-653b6010b806> in <module>() 
     8 print skew_gendata_targets.head(16) 
     9 
---> 10 skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121) 
    11 
    12 #print '\nSkewed Generated Dataset (', len(skew_gendata_data), ')' 

d:\Program Files\Anaconda2\lib\site-packages\sklearn\cross_validation.pyc in __init__(self, y, n_folds, shuffle, random_state) 
    531   for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)): 
    532    for label, (_, test_split) in zip(unique_labels, per_label_splits): 
--> 533     label_test_folds = test_folds[y == label] 
    534     # the test split can be too big because we used 
    535     # KFold(max(c, self.n_folds), self.n_folds) instead of 

IndexError: too many indices for array 
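The "extra 0" printed above the rows is most likely a column label: the output is consistent with skew_gendata_targets being a one-column pandas DataFrame whose column is named 0, which would also explain why .values is 2-d. A minimal sketch, assuming that structure (the targets frame below is a hypothetical stand-in, not the asker's data):

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for skew_gendata_targets: one column, labelled 0 by default.
targets = pd.DataFrame(np.zeros((500, 1), dtype=int))

print(targets.columns.tolist())   # [0]: the label printed above the rows
print(targets.values.shape)       # (500, 1): .values of a DataFrame is 2-d
print(targets[0].values.shape)    # (500,): a single selected column is a 1-d Series
```

Under that assumption, selecting the column (skew_gendata_targets[0]) before taking .values would give the 1-d labels directly.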

Answer

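Check the shape of skew_gendata_targets.values. You will find that it is not the 1-d array (shape (500,)) that StratifiedKFold expects, but a (500, 1) array. sklearn treats the two differently instead of coercing one into the other. Let me know if that helps.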
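The failing line in the traceback, test_folds[y == label], can be reproduced with plain numpy. This sketch assumes, as the sklearn 0.17 source suggests, that test_folds is a 1-d array with one entry per sample; the arrays and the select_label helper are hypothetical stand-ins, not sklearn's code.

```python
import numpy as np

# Hypothetical stand-in for skew_gendata_targets.values: 500 labels as a (500, 1) column.
y_col = np.zeros((500, 1), dtype=int)
y_col[14] = 1

# Assumed to mirror sklearn 0.17: one fold id per sample, stored as a 1-d array.
test_folds = np.zeros(500, dtype=int)


def select_label(test_folds, y, label):
    """Return the entries of test_folds whose sample has the given label."""
    return test_folds[y == label]


try:
    select_label(test_folds, y_col, 0)
    column_failed = False
except IndexError:
    # A (500, 1) boolean mask has two dimensions, but test_folds has only one.
    column_failed = True

y_flat = y_col.ravel()  # shape (500,)
selected = select_label(test_folds, y_flat, 0)

print(column_failed, y_flat.shape, selected.shape)  # True (500,) (499,)
```

With the flattened labels the boolean mask has the same shape as test_folds, so the indexing succeeds.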

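The print output is already in the question: print type(skew_gendata_targets.values), skew_gendata_targets.values.shape reports a (500, 1) numpy array. I'm a MATLAB person who got thrown into the Python pit, and I didn't know there was a difference between a 500x1 and a 500-by-nothing matrix/array/thing. In the MATLAB world, at least, there is no difference. –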

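Yes, it's unfortunate and a bit confusing. The difference matters for operations like '*': two (500,) arrays are multiplied element by element, but a (500, 1) array times a (500,) array broadcasts to a (500, 500) result. Hopefully StratifiedKFold works once you coerce the labels to a (500,) array. – Brian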
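For plain numpy arrays, '*' is always element-wise with broadcasting, and a true matrix product is np.dot (or the @ operator in Python 3.5+). A small sketch with made-up values shows how an (n, 1) versus (n,) shape changes the result:

```python
import numpy as np

a_flat = np.array([1, 2, 3])      # shape (3,)
a_col = a_flat.reshape(-1, 1)     # shape (3, 1), like the (500, 1) labels

elementwise = a_flat * a_flat     # [1, 4, 9], shape (3,)
broadcast = a_col * a_flat        # every pair multiplied, shape (3, 3)
dot = np.dot(a_flat, a_flat)      # inner product: 1 + 4 + 9 = 14

print(elementwise, broadcast.shape, dot)  # [1 4 9] (3, 3) 14
```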

I see. Reshaping matrices is something a MATLAB user can understand, and this seems to have fixed it: np.reshape(skew_gendata_targets.values, [500,]). Thanks! –
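The reshape above hard-codes the length 500. A sketch of equivalent ways to flatten an (n, 1) column without knowing n in advance, using a hypothetical values array in place of skew_gendata_targets.values:

```python
import numpy as np

values = np.arange(500).reshape(500, 1)   # hypothetical (500, 1) label column

fixed = np.reshape(values, [500,])        # the fix from the comment; assumes exactly 500 rows
any_len = values.reshape(-1)              # -1 lets numpy infer the length
flat = values.ravel()                     # flattened 1-d array (a view when possible)
squeezed = np.squeeze(values, axis=1)     # drops only the size-1 column axis

for arr in (fixed, any_len, flat, squeezed):
    assert arr.shape == (500,)
    assert np.array_equal(arr, np.arange(500))
```

Using -1, ravel() or squeeze keeps the code working if the number of samples changes.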