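How do I do machine learning (with a tree) on multiple attributes?

I'm using Python and scikit-learn's tree classifier on a made-up machine learning problem. I have a binary outcome variable (wc_measure) that I believe depends on a few other variables (cash, crisis, industry). I tried the following: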

# import necessary packages
import pandas as pd
import numpy as np
import sklearn as skl
from sklearn import tree
# note: sklearn.cross_validation was removed in scikit-learn 0.20;
# newer versions provide train_test_split in sklearn.model_selection
from sklearn.cross_validation import train_test_split as tts


# import data and give a little overview
sample = pd.read_stata('sample_data.dta')

s = sample


# What I want to learn on
X = [s.crisis, s.cash, s.industry]
y = s.wc_measure
X_train, X_test, y_train, y_test = tts(X, y, test_size = .5)


# let's learn a little

my_tree = tree.DecisionTreeClassifier()
clf = my_tree.fit(X_train, y_train)
predictions = my_tree.predict(X_test)

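I get the following error: Number of labels=50 does not match number of samples=1. If I base X on a single variable (e.g. X = s.crisis), I have to reshape X. I don't fully understand why I'm running into these problems... Any ideas?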
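For the single-variable case, the reshape is needed because scikit-learn estimators expect X to be 2-D with shape (n_samples, n_features), while a single pandas column is 1-D. A minimal sketch, using a small made-up frame in place of sample_data.dta (the values below are hypothetical):

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in for sample_data.dta (made-up values)
s = pd.DataFrame({
    "crisis": [1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 5.0],
    "wc_measure": [0, 0, 1, 1, 1, 0, 0, 1],
})

print(s.crisis.shape)  # (8,): 1-D, no feature axis

# Two equivalent ways to get a 2-D (n_samples, 1) matrix for one feature
X_reshaped = s.crisis.to_numpy().reshape(-1, 1)  # -1 lets NumPy infer the row count
X_frame = s[["crisis"]]                          # double brackets keep a DataFrame

print(X_reshaped.shape, X_frame.shape)  # (8, 1) (8, 1)

clf = DecisionTreeClassifier(random_state=0).fit(X_frame, s.wc_measure)
# The toy labels split cleanly at crisis = 2.5, so the tree reproduces them
print(clf.predict(X_frame).tolist())  # [0, 0, 1, 1, 1, 0, 0, 1]
```

With several variables the same rule applies: each row of X must be one observation, which is exactly what selecting several columns with double brackets gives you.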


PS: Here is the output of print(X):

[0  4.0 
1  4.0 
2  5.0 
3  3.0 
4  4.0 
5  2.0 
6  2.0 
7  1.0 
8  3.0 
9  3.0 
10 4.0 
11 3.0 
12 2.0 
13 4.0 
14 5.0 
15 4.0 
16 2.0 
17 2.0 
18 3.0 
19 2.0 
20 5.0 
21 4.0 
22 2.0 
23 4.0 
24 5.0 
25 1.0 
26 5.0 
27 3.0 
28 4.0 
29 2.0 
    ... 
70 1.0 
71 4.0 
72 4.0 
73 1.0 
74 4.0 
75 3.0 
76 4.0 
77 2.0 
78 2.0 
79 5.0 
80 2.0 
81 3.0 
82 5.0 
83 4.0 
84 4.0 
85 5.0 
86 3.0 
87 3.0 
88 4.0 
89 2.0 
90 2.0 
91 3.0 
92 3.0 
93 4.0 
94 3.0 
95 1.0 
96 4.0 
97 2.0 
98 3.0 
99 4.0 
Name: crisis, dtype: float32, 0  450.283417 
1  113.472214 
2  11.811784 
3  1007.507446 
4  293.895142 
5  1133.297729 
6  2237.830322 
7  1475.787109 
8  283.363678 
9  626.888794 
10  38.865730 
11  991.999390 
12 1115.746948 
13  373.537231 
14  97.570717 
15  136.079193 
16 2560.691406 
17  667.062073 
18 1378.384521 
19  152.716400 
20  5.779267 
21  481.511566 
22  677.809631 
23  722.521790 
24  32.927990 
25 2504.450928 
26  17.422865 
27  651.585083 
28  549.469177 
29  297.458527 
     ...  
70 1198.370239 
71  471.343933 
72  389.709290 
73 2962.622803 
74  581.519287 
75 1148.822388 
76  67.653664 
77 1346.391602 
78 1764.086914 
79  14.308219 
80  973.152161 
81  552.576904 
82  2.863116 
83  425.520752 
84  321.773682 
85  63.597332 
86 1351.122559 
87  735.856567 
88  745.656677 
89 2784.453125 
90 1438.272705 
91  768.780823 
92  827.021423 
93  591.778015 
94  885.169434 
95 1143.088867 
96  399.816803 
97 1517.454834 
98 1311.692505 
99  533.062561 
Name: cash, dtype: float32, 0  5.0 
1  2.0 
2  3.0 
3  5.0 
4  4.0 
5  3.0 
6  5.0 
7  1.0 
8  1.0 
9  2.0 
10 1.0 
11 5.0 
12 2.0 
13 4.0 
14 6.0 
15 2.0 
16 6.0 
17 2.0 
18 5.0 
19 1.0 
20 3.0 
21 4.0 
22 2.0 
23 6.0 
24 4.0 
25 4.0 
26 3.0 
27 3.0 
28 5.0 
29 1.0 
    ... 
70 2.0 
71 4.0 
72 3.0 
73 6.0 
74 6.0 
75 5.0 
76 1.0 
77 3.0 
78 5.0 
79 4.0 
80 2.0 
81 3.0 
82 2.0 
83 5.0 
84 3.0 
85 5.0 
86 5.0 
87 4.0 
88 6.0 
89 6.0 
90 4.0 
91 3.0 
92 4.0 
93 6.0 
94 3.0 
95 2.0 
96 3.0 
97 4.0 
98 6.0 
99 4.0 

PPS: Here is how I generated the data in Stata:

clear matrix 
clear all 
set more off 

set obs 100 
gen id = _n 


*Basics 
    gen industry = round(runiform()*5+1) 
    gen activity = round(runiform()*5+1) 
    gen crisis = round(runiform()*4+1) 
     egen min_crisis = min(crisis) 
     egen max_crisis = max(crisis) 
     gen n_crisis = (crisis - min_crisis)/(max_crisis-min_crisis) 

*Company details 
    gen staff = round((0.5 * industry + 0.3 * activity - 0.2 * crisis) * runiform()*100+1) 

    gen revenue = (0.5 * industry + 0.2 * activity - 0.3 * crisis) * 1000 + runiform() 
     replace revenue = 0 if revenue<0 

    *Working Capital (wc) 
    gen stock = runiform()*0.5*crisis*revenue 
    gen receivables = runiform()*0.5*crisis*revenue 
    gen payables = runiform()*-0.5*crisis*revenue 
     replace payables = 0 if payables < 0 
    gen wc = stock + receivables - payables 
     egen avg_wc = mean(wc), by(industry) 


    *Liquidity 
    gen loan = (0.5 * industry + 0.2 * activity - 0.3 * crisis) * 1000 + runiform() 
     replace loan = 0 if loan<0 
     egen pc_loan = pctile(loan), p(0.2) by(industry) 
     replace loan = 0 if loan<pc_loan 

    gen current_debt = n_crisis * loan + runiform()*100 

    gen cash = (1-n_crisis)*revenue + runiform()*100 


*Measures 

    *WC-measure (binary) 
     gen wc_status = (wc-avg_wc) 
      egen max_wc_status = max(wc_status), by(industry) 
      egen min_wc_status = min(wc_status), by(industry) 
      gen n_wc_status = (wc_status - min_wc_status)/(max_wc_status-min_wc_status) 
    gen wc_measure = round(n_wc_status) 
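Since sample_data.dta can't be shared, here is a rough pandas/NumPy translation of the parts of the Stata script that feed the model, so the question can be reproduced without Stata. It is a sketch under stated assumptions: `make_sample` and `stata_round` are hypothetical helpers, staff, loan and current_debt are left out, and because the random draws differ from Stata's, the values will not match the original file.

```python
import numpy as np
import pandas as pd


def stata_round(x):
    """Approximate Stata's round() for positive values: .5 rounds up
    (np.round would round halves to even instead)."""
    return np.floor(x + 0.5)


def make_sample(n=100, seed=0):
    """Hypothetical helper: approximate the Stata script's variables in pandas."""
    rng = np.random.default_rng(seed)

    def runiform():
        # Stand-in for Stata's runiform(): one uniform draw per observation
        return rng.random(n)

    df = pd.DataFrame({"id": np.arange(1, n + 1)})

    # Basics
    df["industry"] = stata_round(runiform() * 5 + 1)
    df["activity"] = stata_round(runiform() * 5 + 1)
    df["crisis"] = stata_round(runiform() * 4 + 1)
    c = df["crisis"]
    df["n_crisis"] = (c - c.min()) / (c.max() - c.min())

    # Company details (staff omitted)
    base = 0.5 * df["industry"] + 0.2 * df["activity"] - 0.3 * df["crisis"]
    df["revenue"] = (base * 1000 + runiform()).clip(lower=0)

    # Working capital
    scale = df["crisis"] * df["revenue"]
    df["stock"] = runiform() * 0.5 * scale
    df["receivables"] = runiform() * 0.5 * scale
    df["payables"] = (runiform() * -0.5 * scale).clip(lower=0)
    df["wc"] = df["stock"] + df["receivables"] - df["payables"]
    df["avg_wc"] = df.groupby("industry")["wc"].transform("mean")

    # Liquidity (loan and current_debt omitted)
    df["cash"] = (1 - df["n_crisis"]) * df["revenue"] + runiform() * 100

    # Binary WC measure, normalised within each industry
    df["wc_status"] = df["wc"] - df["avg_wc"]
    grp = df.groupby("industry")["wc_status"]
    lo, hi = grp.transform("min"), grp.transform("max")
    df["wc_measure"] = stata_round((df["wc_status"] - lo) / (hi - lo))
    return df
```

For example, make_sample()[["crisis", "cash", "industry", "wc_measure"]] gives a 100-row frame with the same columns the question uses.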

Can you share the sample_data.dta file? – Xevaquor

Are you passing the correct input to tts? –

I can't share it, but I'll upload a Stata script that shows how I created the data. – Rachel

Answers

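I finally solved the problem. The issue was that I hadn't defined my sample s as an array; as a result, X was a list. Thanks everyone for the help!

Here is what I did: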

# import data and give a little overview
sample = pd.read_stata('sample_data.dta')
s = sample
print(s.shape)


# Have some more vars and an array of explanatory vars


# np.array((...)) has shape (3, 100), one row per variable; transpose it
# to (100, 3) so each row is one observation (reshape(100, 3) would
# regroup the values in memory order without transposing them)
X = np.array((s.crisis, s.cash, s.industry)).T
y = np.array(s.wc_measure)
X_train, X_test, y_train, y_test = tts(X, y, test_size = .8)


# let's learn a little

my_tree = tree.DecisionTreeClassifier()
clf = my_tree.fit(X_train, y_train)
predictions = my_tree.predict(X_test)
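A caveat about building X from a tuple of columns: np.array((s.crisis, s.cash, s.industry)) has shape (3, n), one row per variable. reshape(n, 3) only regroups the values in memory order, so the result has the right shape but each row mixes values from different variables. Transposing with .T, or selecting the columns from the DataFrame, gives one row per observation. A small check on a made-up 4-row frame (the values are hypothetical):

```python
import numpy as np
import pandas as pd

# Hypothetical 4-row frame standing in for the 100-row sample
s = pd.DataFrame({
    "crisis": [1, 2, 3, 4],
    "cash": [10.0, 20.0, 30.0, 40.0],
    "industry": [5, 6, 7, 8],
})

stacked = np.array((s.crisis, s.cash, s.industry))  # shape (3, 4): one row per variable
scrambled = stacked.reshape(4, 3)  # shape (4, 3), values regrouped, not transposed
transposed = stacked.T             # shape (4, 3): one row per observation
selected = s[["crisis", "cash", "industry"]].to_numpy()  # same values as transposed

print(scrambled[0].tolist())   # [1.0, 2.0, 3.0] -> three crisis values, not one row
print(transposed[0].tolist())  # [1.0, 10.0, 5.0] -> crisis, cash, industry of row 0
```

The column-selection form is usually the least error-prone, since it never depends on the row count being typed correctly.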

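You need to check whether X is a valid input for tts. Your X has three rows and N columns, but it should have N rows and 3 attributes (columns). That's why it complains that the numbers don't match.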
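The mismatch is easy to see by asking NumPy for the shape of each candidate X. A sketch with randomly generated stand-in data (the column names follow the question; the values, the 100-row size and random_state are assumptions). It imports train_test_split from sklearn.model_selection, where current scikit-learn versions provide it:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
n = 100
# Hypothetical stand-in for sample_data.dta
s = pd.DataFrame({
    "crisis": rng.integers(1, 6, n).astype(float),
    "cash": rng.random(n) * 3000,
    "industry": rng.integers(1, 7, n).astype(float),
    "wc_measure": rng.integers(0, 2, n),
})

X_list = [s.crisis, s.cash, s.industry]
print(np.shape(X_list))  # (3, 100): read as 3 "samples" with 100 features each

X = s[["crisis", "cash", "industry"]]  # one row per observation, one column per feature
y = s.wc_measure
print(X.shape, y.shape)  # (100, 3) (100,)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
predictions = clf.predict(X_test)
print(predictions.shape)  # (50,)
```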

I still don't really get this; I'm completely new to Python. I assumed X was structured as N rows / 3 columns. How can I check? – Rachel

It should look like array([[1,2],[3,4],[5,6],[7,9]]). That array has two columns and four rows. Take a look here: http://scikit-learn.org/stable/modules/generated/sklearn.cross_validation.train_test_split.html, which has an example of train_test_split. Hope this helps –
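To make the comment concrete: an array's shape reports (rows, columns), and train_test_split divides along the first axis, so rows (samples) are split while columns (features) are kept. A short sketch using the 4×2 array from the comment, with made-up labels y (the labels and random_state are assumptions):

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.array([[1, 2], [3, 4], [5, 6], [7, 9]])  # 4 rows (samples), 2 columns (features)
y = np.array([0, 1, 0, 1])                       # hypothetical labels, one per row

print(X.shape)  # (4, 2)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
print(X_train.shape, X_test.shape)  # (2, 2) (2, 2): rows are split, columns are kept
```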

Thanks! I know how to print the first few rows and so on, but how do I figure out the structure of an array? – Rachel
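Beyond printing rows, a few attributes describe an array's structure: ndim is the number of axes, shape the length of each axis, and dtype the element type. For scikit-learn, X should have ndim == 2 and shape == (n_samples, n_features). A sketch on a small array whose values are rounded from the first rows printed in the question (assuming the third printed Series is industry):

```python
import numpy as np

# Columns: crisis, cash, industry (first 5 observations, rounded)
X = np.array([
    [4.0, 450.3, 5.0],
    [4.0, 113.5, 2.0],
    [5.0, 11.8, 3.0],
    [3.0, 1007.5, 5.0],
    [4.0, 293.9, 4.0],
])

print(X.ndim)   # 2 -> a table of rows and columns
print(X.shape)  # (5, 3) -> 5 samples, 3 features
print(X.dtype)  # float64

print(X[0])     # first row: all features of the first observation
print(X[:, 1])  # second column: the cash value of every observation
```

If shape comes out as (3, n) instead of (n, 3), the variables are laid out as rows and X needs transposing before it goes into train_test_split.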