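StratifiedKFold: IndexError: too many indices for array
I'm using sklearn's StratifiedKFold function. Can someone help me understand the error? My guess is that it has to do with my input array of labels. When I print them (the first 16 in this example), the index runs from 0 to 15, but an extra 0 that I didn't expect is printed above them. Maybe I'm just a Python noob, but it looks strange.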
Has anyone seen this kind of shenanigans before?
Documentation: http://scikit-learn.org...StratifiedKFold.html
Code:
import nltk
import sklearn
print('The nltk version is {}.'.format(nltk.__version__))
print('The scikit-learn version is {}.'.format(sklearn.__version__))
print type(skew_gendata_targets.values), skew_gendata_targets.values.shape
print skew_gendata_targets.head(16)
skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121)
Output:
The nltk version is 3.1.
The scikit-learn version is 0.17.
<type 'numpy.ndarray'> (500L, 1L)
0
0 0
1 0
2 0
3 0
4 0
5 0
6 0
7 0
8 0
9 0
10 0
11 0
12 0
13 0
14 1
15 0
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-373-653b6010b806> in <module>()
8 print skew_gendata_targets.head(16)
9
---> 10 skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121)
11
12 #print '\nSkewed Generated Dataset (', len(skew_gendata_data), ')'
d:\Program Files\Anaconda2\lib\site-packages\sklearn\cross_validation.pyc in __init__(self, y, n_folds, shuffle, random_state)
531 for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)):
532 for label, (_, test_split) in zip(unique_labels, per_label_splits):
--> 533 label_test_folds = test_folds[y == label]
534 # the test split can be too big because we used
535 # KFold(max(c, self.n_folds), self.n_folds) instead of
IndexError: too many indices for array
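The failing line is `label_test_folds = test_folds[y == label]`. Presumably `test_folds` is a 1-D array with one entry per sample, while `y` is the (500, 1) array passed in, so `y == label` is a 2-D boolean mask. The numpy-only sketch below reproduces that situation with hypothetical stand-ins (`test_folds`, `y_column`, `y_flat` are illustrative names, not scikit-learn internals) and shows that flattening the labels makes the same indexing work.

```python
import numpy as np

# Hypothetical stand-ins: a 1-D array of fold ids (one per sample) and
# labels shaped as a (500, 1) column, like skew_gendata_targets.values.
n_samples = 500
test_folds = np.zeros(n_samples, dtype=int)      # shape (500,)
y_column = np.zeros((n_samples, 1), dtype=int)   # shape (500, 1)
y_column[14, 0] = 1

mask = (y_column == 0)        # the comparison keeps the 2-D shape: (500, 1)
print(mask.shape)

raised = False
try:
    test_folds[mask]          # a 2-D boolean mask cannot index a 1-D array
except IndexError as exc:
    raised = True
    print("IndexError:", exc)  # exact wording depends on the numpy version

y_flat = y_column.ravel()     # shape (500,)
print(test_folds[y_flat == 0].shape)  # 1-D mask works: selects 499 entries
```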
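The printout in the question from print type(skew_gendata_targets.values), skew_gendata_targets.values.shape shows that it is a (500, 1) numpy array. I'm a MATLAB fan who got thrown into the Python pit, and I don't know the difference between a 500x1 and a 500x-nothing matrix/array/thingy. At least in the MATLAB world there is no difference. –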
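The extra "0" in the printout is consistent with `skew_gendata_targets` being a one-column pandas DataFrame whose column label is the integer 0: the lone "0" is the column header. A DataFrame's `.values` is always 2-D, whereas selecting the single column gives a Series whose `.values` is 1-D. The sketch below uses a hypothetical stand-in `targets` built to match the printout.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for skew_gendata_targets: one column labeled 0.
targets = pd.DataFrame(np.zeros((500, 1), dtype=int))
targets.iloc[14, 0] = 1

print(targets.head(3))        # the header row shows the column label 0

as_2d = targets.values        # DataFrame -> 2-D array, shape (500, 1)
as_1d = targets[0].values     # column 0 as a Series -> 1-D array, shape (500,)
print(as_2d.shape, as_2d.ndim)
print(as_1d.shape, as_1d.ndim)
```

Unlike MATLAB, where a 500x1 array is simply a column vector, numpy treats (500, 1) as a 2-D array and (500,) as a distinct 1-D array.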
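Yes, it's unfortunate and a bit confusing. The difference matters for operations like '*': two (500,) arrays multiply element-wise, but a (500, 1) array times a (500,) array broadcasts into a (500, 500) result. Hopefully StratifiedKFold will work once you coerce the labels into a (500,) array. – Brian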
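For plain numpy arrays, '*' is always element-wise; the shape difference shows up through broadcasting, and a true matrix product needs np.dot (or '*' on the legacy np.matrix type). A small sketch with length-3 arrays:

```python
import numpy as np

a_1d = np.arange(3)            # shape (3,): [0 1 2]
a_col = a_1d.reshape(3, 1)     # same values as a (3, 1) column

elementwise = a_1d * a_1d      # (3,) * (3,) -> element-wise, shape (3,)
broadcast = a_col * a_1d       # (3, 1) * (3,) broadcasts to shape (3, 3)
matmul = np.dot(a_1d, a_1d)    # explicit product: 0*0 + 1*1 + 2*2

print(elementwise)             # [0 1 4]
print(broadcast.shape)         # (3, 3)
print(matmul)                  # 5
```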
I see. Reshaping matrices is something a MATLABer can understand, and that seems to have fixed it: np.reshape(skew_gendata_targets.values, [500,]). Thanks! –
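The reshape in the comment hard-codes the length 500. The sketch below, using a hypothetical stand-in `values` for skew_gendata_targets.values, shows two equivalent ways to flatten the labels without that constant; the resulting 1-D array is what would be passed as the first argument to cross_validation.StratifiedKFold(..., n_folds=10, shuffle=True, random_state=20160121). In newer scikit-learn releases the class lives in sklearn.model_selection and takes n_splits, with the labels passed to its split(X, y) method instead of the constructor.

```python
import numpy as np

# Hypothetical stand-in for skew_gendata_targets.values: a (500, 1) column.
values = np.zeros((500, 1), dtype=int)
values[14, 0] = 1

fixed_reshape = np.reshape(values, [500,])  # the fix from the comment
fixed_ravel = values.ravel()                # same result, no hard-coded length
fixed_column = values[:, 0]                 # or take the single column directly

print(fixed_reshape.shape, fixed_ravel.shape, fixed_column.shape)
print(np.array_equal(fixed_reshape, fixed_ravel)
      and np.array_equal(fixed_ravel, fixed_column))

# With 1-D labels, a mask like `y == label` is 1-D and can index a 1-D array.
test_folds = np.zeros(500, dtype=int)
print(test_folds[fixed_ravel == 1].shape)   # one sample has label 1
```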