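How do I do machine learning (tree) on multiple attributes?

I am using Python and scikit-learn's tree classifier on a fictional machine learning problem. I have a binary outcome variable (wc_measure) that I believe depends on several other variables (cash, crisis and industry). I tried the following: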
# import necessary packages
import pandas as pd
import numpy as np
import sklearn as skl
from sklearn import tree
from sklearn.cross_validation import train_test_split as tts
# import data and give a little overview
sample = pd.read_stata('sample_data.dta')
s = sample
# What I want to learn on
X = [s.crisis, s.cash, s.industry]
y = s.wc_measure
X_train, X_test, y_train, y_test = tts(X, y, test_size = .5)
#let's learn a little
my_tree = tree.DecisionTreeClassifier()
clf = my_tree.fit(X_train, y_train)
predictions = my_tree.predict(X_test)
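I get the following error: Number of labels=50 does not match number of samples=1. If I base X on a single variable instead (e.g. X = s.crisis), I am told that I need to reshape X. I don't fully understand why I run into these problems... any ideas?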
PS: This is the output of print(X):
[0 4.0
1 4.0
2 5.0
3 3.0
4 4.0
5 2.0
6 2.0
7 1.0
8 3.0
9 3.0
10 4.0
11 3.0
12 2.0
13 4.0
14 5.0
15 4.0
16 2.0
17 2.0
18 3.0
19 2.0
20 5.0
21 4.0
22 2.0
23 4.0
24 5.0
25 1.0
26 5.0
27 3.0
28 4.0
29 2.0
...
70 1.0
71 4.0
72 4.0
73 1.0
74 4.0
75 3.0
76 4.0
77 2.0
78 2.0
79 5.0
80 2.0
81 3.0
82 5.0
83 4.0
84 4.0
85 5.0
86 3.0
87 3.0
88 4.0
89 2.0
90 2.0
91 3.0
92 3.0
93 4.0
94 3.0
95 1.0
96 4.0
97 2.0
98 3.0
99 4.0
Name: crisis, dtype: float32, 0 450.283417
1 113.472214
2 11.811784
3 1007.507446
4 293.895142
5 1133.297729
6 2237.830322
7 1475.787109
8 283.363678
9 626.888794
10 38.865730
11 991.999390
12 1115.746948
13 373.537231
14 97.570717
15 136.079193
16 2560.691406
17 667.062073
18 1378.384521
19 152.716400
20 5.779267
21 481.511566
22 677.809631
23 722.521790
24 32.927990
25 2504.450928
26 17.422865
27 651.585083
28 549.469177
29 297.458527
...
70 1198.370239
71 471.343933
72 389.709290
73 2962.622803
74 581.519287
75 1148.822388
76 67.653664
77 1346.391602
78 1764.086914
79 14.308219
80 973.152161
81 552.576904
82 2.863116
83 425.520752
84 321.773682
85 63.597332
86 1351.122559
87 735.856567
88 745.656677
89 2784.453125
90 1438.272705
91 768.780823
92 827.021423
93 591.778015
94 885.169434
95 1143.088867
96 399.816803
97 1517.454834
98 1311.692505
99 533.062561
Name: cash, dtype: float32, 0 5.0
1 2.0
2 3.0
3 5.0
4 4.0
5 3.0
6 5.0
7 1.0
8 1.0
9 2.0
10 1.0
11 5.0
12 2.0
13 4.0
14 6.0
15 2.0
16 6.0
17 2.0
18 5.0
19 1.0
20 3.0
21 4.0
22 2.0
23 6.0
24 4.0
25 4.0
26 3.0
27 3.0
28 5.0
29 1.0
...
70 2.0
71 4.0
72 3.0
73 6.0
74 6.0
75 5.0
76 1.0
77 3.0
78 5.0
79 4.0
80 2.0
81 3.0
82 2.0
83 5.0
84 3.0
85 5.0
86 5.0
87 4.0
88 6.0
89 6.0
90 4.0
91 3.0
92 4.0
93 6.0
94 3.0
95 2.0
96 3.0
97 4.0
98 6.0
99 4.0
PPS: This is how I generated the data in Stata:
clear matrix
clear all
set more off
set obs 100
gen id = _n
*Basics
gen industry = round(runiform()*5+1)
gen activity = round(runiform()*5+1)
gen crisis = round(runiform()*4+1)
egen min_crisis = min(crisis)
egen max_crisis = max(crisis)
gen n_crisis = (crisis - min_crisis)/(max_crisis-min_crisis)
*Company details
gen staff = round((0.5 * industry + 0.3 * activity - 0.2 * crisis) * runiform()*100+1)
gen revenue = (0.5 * industry + 0.2 * activity - 0.3 * crisis) * 1000 + runiform()
replace revenue = 0 if revenue<0
*Working Capital (wc)
gen stock = runiform()*0.5*crisis*revenue
gen receivables = runiform()*0.5*crisis*revenue
gen payables = runiform()*-0.5*crisis*revenue
replace payables = 0 if payables < 0
gen wc = stock + receivables - payables
egen avg_wc = mean(wc), by(industry)
*Liquidity
gen loan = (0.5 * industry + 0.2 * activity - 0.3 * crisis) * 1000 + runiform()
replace loan = 0 if loan<0
egen pc_loan = pctile(loan), p(0.2) by(industry)
replace loan = 0 if loan<pc_loan
gen current_debt = n_crisis * loan + runiform()*100
gen cash = (1-n_crisis)*revenue + runiform()*100
*Measures
*WC-measure (binary)
gen wc_status = (wc-avg_wc)
egen max_wc_status = max(wc_status), by(industry)
egen min_wc_status = min(wc_status), by(industry)
gen n_wc_status = (wc_status - min_wc_status)/(max_wc_status-min_wc_status)
gen wc_measure = round(n_wc_status)
Can you share the sample_data.dta file? – Xevaquor
Was tts imported correctly? –
I can't share it. But I will upload a Stata script that shows how I created the data. – Rachel
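
On the error in the question, here is a minimal sketch of what seems to be going on. A Python list of three Series is read by scikit-learn as 3 samples with 100 features each, not 100 samples with 3 features, which would explain the mismatch between the number of labels and the number of samples. Selecting the columns from the DataFrame keeps one row per observation. Assumptions: the data below is a random stand-in for sample_data.dta (which was not shared), with the same column names and 100 rows, and train_test_split is imported from sklearn.model_selection because sklearn.cross_validation was removed in newer scikit-learn releases.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for sample_data.dta (assumption: 100 rows, same column names)
rng = np.random.default_rng(0)
s = pd.DataFrame({
    "crisis": rng.integers(1, 6, size=100).astype("float32"),
    "cash": rng.uniform(0, 3000, size=100).astype("float32"),
    "industry": rng.integers(1, 7, size=100).astype("float32"),
    "wc_measure": rng.integers(0, 2, size=100),
})

# A list of three Series becomes 3 rows (samples) of 100 columns (features)
X_wrong = np.asarray([s.crisis, s.cash, s.industry])
print(X_wrong.shape)

# Selecting columns from the DataFrame gives one row per sample
X = s[["crisis", "cash", "industry"]]
y = s["wc_measure"]
print(X.shape)

# Split, fit and predict as in the question
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=0
)
clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
predictions = clf.predict(X_test)
print(predictions.shape)

# Single feature: double brackets keep X two-dimensional,
# so no separate reshape is needed
X_single = s[["crisis"]]
print(X_single.shape)
```

The single-variable case fails for a related reason: s.crisis is one-dimensional, while the classifier expects a two-dimensional X with shape (n_samples, n_features). s[["crisis"]] is one way to get that; s.crisis.values.reshape(-1, 1) should work as well.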