2013-11-20 119 views
11

我一直試圖在迴歸樹(或隨機森林迴歸)中使用分類inpust,但sklearn不斷返回錯誤並要求數字輸入。迴歸樹或隨機森林迴歸與分類輸入

import sklearn as sk 
MODEL = sk.ensemble.RandomForestRegressor(n_estimators=100) 
MODEL.fit([('a',1,2),('b',2,3),('a',3,2),('b',1,3)], [1,2.5,3,4]) # does not work 
MODEL.fit([(1,1,2),(2,2,3),(1,3,2),(2,1,3)], [1,2.5,3,4]) #works 

MODEL = sk.tree.DecisionTreeRegressor() 
MODEL.fit([('a',1,2),('b',2,3),('a',3,2),('b',1,3)], [1,2.5,3,4]) # does not work 
MODEL.fit([(1,1,2),(2,2,3),(1,3,2),(2,1,3)], [1,2.5,3,4]) #works 

據我的理解,分類輸入應該可能在這些方法中沒有任何轉換(例如WOE替換)。

有沒有其他人有這個困難?

謝謝!

回答

16

scikit-learn具有用於分類變量沒有專門表示(又名中的R因子),一個可能的解決方案是使用LabelEncoder來編碼字符串作爲int

import numpy as np 
from sklearn.preprocessing import LabelEncoder 
from sklearn.ensemble import RandomForestRegressor 

X = np.asarray([('a',1,2),('b',2,3),('a',3,2),('c',1,3)]) 
y = np.asarray([1,2.5,3,4]) 

# transform 1st column to numbers 
X[:, 0] = LabelEncoder().fit_transform(X[:,0]) 

regressor = RandomForestRegressor(n_estimators=150, min_samples_split=2) 
regressor.fit(X, y) 
print(X) 
print(regressor.predict(X)) 

輸出:

[[ 0. 1. 2.] 
[ 1. 2. 3.] 
[ 0. 3. 2.] 
[ 2. 1. 3.]] 
[ 1.61333333 2.13666667 2.53333333 2.95333333] 

但是記住如果ab是獨立的類別,並且它僅適用於基於樹的估計器,則這是輕微的黑客攻擊。爲什麼?因爲b並不比a大。正確的方法是在LabelEncoderpd.get_dummies之後使用OneHotEncoderX[:, 0]生成兩個單獨的單熱編碼列。

import numpy as np 
from sklearn.preprocessing import LabelEncoder, OneHotEncoder 
from sklearn.ensemble import RandomForestRegressor 

X = np.asarray([('a',1,2),('b',2,3),('a',3,2),('c',1,3)]) 
y = np.asarray([1,2.5,3,4]) 

# transform 1st column to numbers 
import pandas as pd 
X_0 = pd.get_dummies(X[:, 0]).values 
X = np.column_stack([X_0, X[:, 1:]]) 

regressor = RandomForestRegressor(n_estimators=150, min_samples_split=2) 
regressor.fit(X, y) 
print(X) 
print(regressor.predict(X)) 
+1

謝謝你。我不認爲它解決了這個問題, '數字標籤'創建了一個線性級數的假設,這很可能與您試圖預測的結果不符。設想一個決策樹節點,當決定使用例如'<2 and > = 2'的下一個截斷分割時,它與「如果在('a','c')」中沒有相同的意義。 – jpsfer

+0

我誤解了你的問題。我剛纔看到你想把所有事情都視爲絕對的。我會相應地更新示例... – Matt

+0

非常感謝馬特! – jpsfer

1

你必須在Python中手動僞代碼。我建議使用pandas.get_dummies()作爲一個熱門編碼。對於Boosted樹,我使用factorize()實現了Ordinal編碼。

對於這類東西還有一整套包裝here

欲瞭解更多詳細的解釋請看this Data Science Stack Exchange的帖子。