4

我不能爲我的生活弄清楚如何計算rpart上的混淆矩陣。來自rpart的混淆矩陣

這裏是我做了什麼:

set.seed(12345) 
UBANK_rand <- UBank[order(runif(1000)), ] 
UBank_train <- UBank_rand[1:900, ] 
UBank_test <- UBank_rand[901:1000, ] 


dim(UBank_train) 
dim(UBank_test) 

#Build the formula for the Decision Tree 
UB_tree <- Personal.Loan ~ Experience + Age+ Income +ZIP.Code + Family + CCAvg + Education 

#Building the Decision Tree from Test Data 
UB_rpart <- rpart(UB_tree, data=UBank_train) 

現在,我想我會做類似

table(predict(UB_rpart, UBank_test, UBank_Test$Default)) 

但是,這不是給我一個混淆矩陣。

回答

8

您沒有提供一個可重複的例子,所以我將創建一個合成數據集:

set.seed(144) 
df = data.frame(outcome = as.factor(sample(c(0, 1), 100, replace=T)), 
       x = rnorm(100)) 

predict功能爲rpart模型type="class"將返回預測類爲每個觀察。

library(rpart) 
mod = rpart(outcome ~ x, data=df) 
pred = predict(mod, type="class") 
table(pred) 
# pred 
# 0 1 
# 51 49 

最後,您可以通過預測和真實結果之間運行table建立混淆矩陣:

table(pred, df$outcome) 
# pred 0 1 
# 0 36 15 
# 1 14 35 
-1

您可以嘗試

pred <- predict(UB_rpart, UB_test) confusionMatrix(pred, UB_test$Personal.Loan)