2017-10-08

scikit-learn StratifiedShuffleSplit KeyError with index

This is my pandas DataFrame lots_not_preprocessed_usd:

<class 'pandas.core.frame.DataFrame'> 
Index: 78718 entries, 2017-09-12T18-38-38-076065 to 2017-10-02T07-29-40-245031 
Data columns (total 20 columns): 
created_year    78718 non-null float64 
price      78718 non-null float64 
........ 
decade     78718 non-null int64 
dtypes: float64(8), int64(1), object(11) 
memory usage: 12.6+ MB 

head(1):

artist_name_normalized house created_year description exhibited_in exhibited_in_museums height images max_estimated_price min_estimated_price price provenance provenance_estate_of sale_date sale_id sale_title style title width decade 
    key                    
    2017-09-12T18-38-38-076065 NaN c11 1862.0 An Album and a small Quantity of unframed Draw... NaN NaN NaN NaN 535.031166 267.515583 845.349242 NaN NaN 1998-06-21 8033 OILS, WATERCOLOURS & DRAWINGS FROM 18TH - 20TH... watercolor painting An Album and a small Quantity of unframed Draw... NaN 186 

My script:

from sklearn.model_selection import StratifiedShuffleSplit 

split = StratifiedShuffleSplit(n_splits=1, test_size =0.2, random_state=42) 
for train_index, test_index in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']): 
    strat_train_set = lots_not_preprocessed_usd.loc[train_index] 
    strat_test_set = lots_not_preprocessed_usd.loc[test_index] 

I am getting this error message:

KeyError         Traceback (most recent call last) 
<ipython-input-224-cee2389254f2> in <module>() 
     3 split = StratifiedShuffleSplit(n_splits=1, test_size =0.2, random_state=42) 
     4 for train_index, test_index in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']): 
----> 5  strat_train_set = lots_not_preprocessed_usd.loc[train_index] 
     6  strat_test_set = lots_not_preprocessed_usd.loc[test_index] 

...... 

KeyError: 'None of [[32199 67509 69003 ..., 44204 2809 56726]] are in the [index]' 

There seems to be a problem with my index (e.g., 2017-09-12T18-38-38-076065), but I don't understand it. Where is the problem?

If I use a different split, it works as expected:

from sklearn.model_selection import train_test_split 

train_set, test_set = train_test_split(lots_not_preprocessed_usd, test_size=0.2, random_state=42) 

Add the output of 'lots_not_preprocessed_usd.head()' for more clarification. – Dark

Answer

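.loc selects rows by label, so the row indexer you pass must contain values that exist in the DataFrame's index. To select rows by ordinary integer position, use .iloc instead of .loc. In the for loop, train_index and test_index are not the timestamp labels of your index, because split.split(X, y) returns arrays of shuffled integer positions.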

... 
for train_index, test_index in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']): 
    strat_train_set = lots_not_preprocessed_usd.iloc[train_index] 
    strat_test_set = lots_not_preprocessed_usd.iloc[test_index] 
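
To make the difference concrete, here is a minimal sketch on a toy frame. The frame `df`, its string labels, and the `positions` list are hypothetical stand-ins for the real data and for the arrays yielded by split.split(). It also shows that, if label-based selection is preferred, the positions can first be translated into labels with `df.index[...]`.

```python
import pandas as pd

# Hypothetical toy frame: string labels stand in for the timestamp-like index.
df = pd.DataFrame({"decade": [186, 186, 187, 187]},
                  index=["a", "b", "c", "d"])

# Integer positions, like the arrays yielded by split.split().
positions = [0, 2]

# .loc searches the index for the labels 0 and 2, which do not exist.
try:
    df.loc[positions]
    loc_failed = False
except KeyError:
    loc_failed = True

# .iloc selects the rows at positions 0 and 2.
by_position = df.iloc[positions]

# Alternative: translate positions to labels first, then use .loc.
by_label = df.loc[df.index[positions]]
```

Both `by_position` and `by_label` contain the same rows; `.iloc` is simply the more direct choice here.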

A sample example:

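import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Random demo data with a date index, so integer positions are not valid labels
lots_not_preprocessed_usd = pd.DataFrame({'some':np.random.randint(5,10,100),'decade':np.random.randint(5,10,100)},index= pd.date_range('5-10-15',periods=100))

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']):
    strat_train_set = lots_not_preprocessed_usd.iloc[train_index]
    strat_test_set = lots_not_preprocessed_usd.iloc[test_index]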

Sample output:

strat_train_set.head() 
 
      decade some 
2015-08-02  6  7 
2015-06-14  7  6 
2015-08-14  7  9 
2015-06-25  9  5 
2015-05-15  7  9 
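
As a quick sanity check, the split can be inspected to see whether it is roughly stratified on 'decade'. The following is only a sketch on seeded random data: the variable names `df`, `rng`, `overall`, `in_test`, and `max_gap` are made up for illustration, and `next(...)` is used in place of the for loop because n_splits=1 yields a single pair of position arrays.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical demo data, seeded so the run is repeatable.
rng = np.random.RandomState(0)
df = pd.DataFrame(
    {"some": rng.randint(5, 10, 100), "decade": rng.randint(5, 10, 100)},
    index=pd.date_range("2015-05-10", periods=100),
)

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

# With n_splits=1 there is exactly one (train, test) pair of position arrays.
train_index, test_index = next(split.split(df, df["decade"]))
strat_train_set = df.iloc[train_index]
strat_test_set = df.iloc[test_index]

# Compare the share of each 'decade' value overall and in the test set.
overall = df["decade"].value_counts(normalize=True).sort_index()
in_test = (
    strat_test_set["decade"]
    .value_counts(normalize=True)
    .reindex(overall.index, fill_value=0)
)
max_gap = (overall - in_test).abs().max()
print(pd.DataFrame({"overall": overall, "test": in_test}))
```

On a test set of only 20 rows the shares cannot match exactly, so a small gap per class is expected.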

Thank you, it works. – zinyosrim