2016-04-20 66 views
5

我想檢查所有到達rpart決策樹中某個節點的觀測值。例如,在下面的代碼:在rpart節點中獲取觀測數據(即:CART)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis) 
fit 

n= 81 

node), split, n, loss, yval, (yprob) 
     * denotes terminal node 

1) root 81 17 absent (0.79.20987654) 
    2) Start>=8.5 62 6 absent (0.90322581 0.09677419) 
    4) Start>=14.5 29 0 absent (1.00000000 0.00000000) * 
    5) Start< 14.5 33 6 absent (0.81818182 0.18181818) 
     10) Age< 55 12 0 absent (1.00000000 0.00000000) * 
     11) Age>=55 21 6 absent (0.71428571 0.28571429) 
     22) Age>=111 14 2 absent (0.85714286 0.14285714) * 
     23) Age< 111 7 3 present (0.42857143 0.57142857) * 
    3) Start< 8.5 19 8 present (0.42105263 0.57894737) * 

我想看看在節點的所有觀測(5)(即:33個觀測爲其開始> = 8.5 &開始< 14.5)。很明顯,我可以手動去他們。但我想有一些像(比如說)「get_node_date」的函數。爲此我可以運行get_node_date(5) - 並獲得相關的觀察結果。

有關如何去解決這個問題的任何建議?

回答

1

似乎沒有這樣的功能,這使得觀測從特定節點提取。我將按如下方式解決它:首先確定哪些規則用於您有興趣的節點。您可以使用path.rpart。然後你可以依次應用規則來提取觀察結果。

這種方法的功能:

get_node_date <- function(tree = fit, node = 5){ 
    rule <- path.rpart(tree, node) 
    rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE)) 
    ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], kyphosis[,x[1]], as.numeric(x[3]))))), 1, all) 
    kyphosis[ind,] 
    } 

對於節點5你:

get_node_date() 

node number: 5 
    root 
    Start>=8.5 
    Start< 14.5 
    Kyphosis Age Number Start 
2 absent 158  3 14 
10 present 59  6 12 
11 present 82  5 14 
14 absent 1  4 12 
18 absent 175  5 13 
20 absent 27  4  9 
23 present 96  3 12 
26 absent 9  5 13 
28 absent 100  3 14 
32 absent 125  2 11 
33 absent 130  5 13 
35 absent 140  5 11 
37 absent 1  3  9 
39 absent 20  6  9 
40 present 91  5 12 
42 absent 35  3 13 
46 present 139  3 10 
48 absent 131  5 13 
50 absent 177  2 14 
51 absent 68  5 10 
57 absent 2  3 13 
59 absent 51  7  9 
60 absent 102  3 13 
66 absent 17  4 10 
68 absent 159  4 13 
69 absent 18  4 11 
71 absent 158  5 14 
72 absent 127  4 12 
74 absent 206  4 10 
77 present 157  3 13 
78 absent 26  7 13 
79 absent 120  2 13 
81 absent 36  4 13 
1

rpart包返回包含您所需要的信息rpart.object元素:

require(rpart) 
fit2 <- rpart(Kyphosis ~ Age + Start, data = kyphosis) 
fit2 

get_node_date <-function(nodeId,fit) 
{ 
    fit$frame[toString(nodeId),"n"] 
} 


for (i in c(1,2,4,5,10,11,22,23,3)) 
    cat(get_node_date(i,fit2),"\n") 
+1

你沒有得到通過這一從而陷入一個類別 – DatamineR

+1

你是對的意見,但只有abservations的數量,誤解的問題 –

1

partykit包也爲此提供了一個解決方案罐頭。您只需將rpart對象轉換爲party類,以便使用其統一接口來處理樹。然後你可以使用data_party()函數。

使用來自問題fit並具有裝library("partykit")可以先要挾rpartparty

pfit <- as.party(fit) 
plot(pfit) 

full pfit tree

只有兩個在方式提取數據小滋擾你想要:(1)原始貼合的model.frame()總是在強制下降,需要手動重新連接。 (2)節點使用不同的編號方案。你現在想要節點4(而不是5)。

pfit$data <- model.frame(fit) 
data4 <- data_party(pfit, 4) 
dim(data4) 
## [1] 33 5 
head(data4) 
## Kyphosis Age Start (fitted) (response) 
## 2 absent 158 14  7  absent 
## 10 present 59 12  8 present 
## 11 present 82 14  8 present 
## 14 absent 1 12  5  absent 
## 18 absent 175 13  7  absent 
## 20 absent 27  9  5  absent 

另一條路線是從子集節點4開始,然後從該取數據的子樹:

pfit4 <- pfit[4] 
plot(pfit4) 

subtree of pfit from node 4

然後data_party(pfit4)給你上述相同data4。而pfit4$data爲您提供的數據沒有(fitted)節點和預測的(response)

+0

如果你使用'$ ptree中的數據< - model.frame( eval(tree $ call $ data))'公式中沒有使用的變量將不會被丟棄 – rawr

+0

True ...但是隻有'data'包含'formula'中的所有變量時,情況不一定如此。通過'model.frame()',您還可以獲得經常在運行中創建的變量變量,例如'log()','Surv()'或'factor()'版本的變量。 –

+0

順便提一句:'rpart'對象的'as.party()'強制現在默認爲data__!因此,您可以執行'as.party(fit,data = TRUE)'(這是新的默認值)或'as.party(fit,data = FALSE)'(對應於舊行爲)。 –

1

另一種方式是,通過查找任何特定節點的所有終端節點並返回呼叫中使用的數據子集來工作。

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis) 

head(subset.rpart(fit, 5)) 
# Kyphosis Age Number Start 
# 2 absent 158  3 14 
# 10 present 59  6 12 
# 11 present 82  5 14 
# 14 absent 1  4 12 
# 18 absent 175  5 13 
# 20 absent 27  4  9 


subset.rpart <- function(tree, node = 1L) { 
    data <- eval(tree$call$data, parent.frame(1L)) 
    wh <- sapply(as.integer(rownames(tree$frame)), parent) 
    wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)])) 
    data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ] 
} 

parent <- function(x) { 
    if (x[1] != 1) 
    c(Recall(if (x %% 2 == 0L) x/2 else (x - 1)/2), x) else x 
}