2016-11-28 106 views
3

我想找到邊界決策函數來分類我的數據。這是他們的一個例子。使用knn分類器的邊界決策

"Distance","Dihedral","Categ" 
    4.083,82.267,C 
    4.132,87.073,C 
    4.713,-80.999,C 
    3.427,-48.144,NC 
    3.663,96.994,C 
    3.99,71.919,C 
    3.484,78.684,C 

到目前爲止,我有knn模型,但我想繪製非線性決策邊界。在我搜索的例子中,有一些變量我不知道在哪裏使用它們或者它們是什麼意思。我正在談論我在「統計學習元素」一書中找到的這個例子

library(ElemStatLearn) 
require(class) 
x <- mixture.example$x 
g <- mixture.example$y 
xnew <- mixture.example$xnew 
mod15 <- knn(x, xnew, g, k=15, prob=TRUE) 
prob <- attr(mod15, "prob") 
prob <- ifelse(mod15=="1", prob, 1-prob) 
px1 <- mixture.example$px1 
px2 <- mixture.example$px2 
prob15 <- matrix(prob, length(px1), length(px2)) 
par(mar=rep(2,4)) 
contour(px1, px2, prob15, levels=0.5, labels="", xlab="", ylab="", main= 
     "15-nearest neighbour", axes=FALSE) 
points(x, col=ifelse(g==1, "coral", "cornflowerblue")) 
gd <- expand.grid(x=px1, y=px2) 
points(gd, pch=".", cex=1.2, col=ifelse(prob15>0.5, "coral", "cornflowerblue")) 
box() 

px1和px2究竟是什麼?我是否需要類似的變量來處理特定情況?

非常感謝您的幫助!

+0

我認爲PX1和PX2只是描述用於新的數據網格的載體,即沿其中有新數據x和y軸的點。 – Andrie

回答

1

我已重寫並註釋了示例,以說明發生了什麼。

該示例構造了一個測試集,它只是一個擴展網格,可以轉換整個測試集。因此,px1是描述測試數據網格x分量的向量,px2與y相似。然後xnewexpand.grid()的結果。

請嘗試下面的代碼,在這裏應該合理清楚。我還修改了k值,並提供了一種使用您選擇的時間間隔構造xnew的簡單方法。

library(ElemStatLearn) 
require(class) 

# Use the training data from mixture.example 
x <- mixture.example$x 
g <- mixture.example$y 

# Construct a test grid using the extent of the training data 
xx_range <- round(range(x[, 1]), 1) 
xy_range <- round(range(x[, 2]), 1) 

nnn <- 0.1 
px1 <- seq(xx_range[1], xx_range[2], by = nnn) # vector with x extent 
px2 <- seq(xy_range[1], xy_range[2], by = nnn) # vector with y extent 
xnew <- as.matrix(expand.grid(px1, px2))  # matrix of new values 

# Train a model 
k <- 10 
mod15 <- knn(x, xnew, g, k=k, prob=TRUE) 
prob <- attr(mod15, "prob") 
prob <- ifelse(mod15=="1", prob, 1-prob) 
prob15 <- matrix(prob, length(px1), length(px2)) 

# Plot the results 
par(mar=rep(2,4)) 
contour(px1, px2, prob15, levels=0.5, labels="", xlab="", ylab="", main= 
      sprintf("%d-nearest neighbour", k), axes=FALSE) 
points(x, col=ifelse(g==1, "coral", "cornflowerblue")) 
points(xnew, pch=".", cex=1.2, col=ifelse(prob15>0.5, "coral", "cornflowerblue")) 
box() 

enter image description here