2016-09-16

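Tidy approach to testing model parameters

I want to compare model performance for a bunch of models that use the same predictors but different model parameters. This seems like a place to use broom to create tidy output, but I can't figure it out. Here is some non-working code that may help suggest what I'm thinking: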

seq(1:10) %>% 
do(fit = knn(train_Market, test_Market, train_Direction, k=.), score = mean(fit==test_Direction)) %>% 
tidy() 

For more context, this is part of an ISLR lab that we're trying to tidyverse-ify. The whole lab can be seen here: https://github.com/AmeliaMN/tidy-islr/blob/master/lab3/lab3.Rmd

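[Update: reproducible example] It's hard to make a small example here, because the data need a fair amount of wrangling before the models can be fit, but this should be reproducible: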

library(ISLR) 
library(dplyr) 
library(class)  # provides knn()

train = Smarket %>% 
    filter(Year < 2005) 
test = Smarket %>% 
    filter(Year >= 2005) 

train_Market = train %>% 
    select(Lag1, Lag2) 
test_Market = test %>% 
    select(Lag1, Lag2) 

train_Direction = train %>% 
    select(Direction) %>% 
    .$Direction 
test_Direction = test %>% 
    select(Direction) %>% 
    .$Direction 

set.seed(1) 
knn_pred = knn(train_Market, test_Market, train_Direction, k=1) 
mean(knn_pred==test_Direction) 

knn_pred = knn(train_Market, test_Market, train_Direction, k=3) 
mean(knn_pred==test_Direction) 

knn_pred = knn(train_Market, test_Market, train_Direction, k=4) 
mean(knn_pred==test_Direction) 
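Before reaching for a tidy data frame, the three repeated calls above can be collapsed into a single loop over k. A minimal sketch, assuming the purrr package is installed; k_values and accuracy are hypothetical names, and the setup is repeated so the snippet stands alone:

```r
library(ISLR)   # Smarket data
library(dplyr)
library(purrr)  # map_dbl()
library(class)  # knn()

# Same split as above: years before 2005 for training, 2005 for testing
train <- filter(Smarket, Year < 2005)
test  <- filter(Smarket, Year >= 2005)
train_Market    <- select(train, Lag1, Lag2)
test_Market     <- select(test, Lag1, Lag2)
train_Direction <- train$Direction
test_Direction  <- test$Direction

set.seed(1)
k_values <- c(1, 3, 4)  # hypothetical: the k values tried above
# One knn() call per k; each factor of predictions is reduced
# immediately to the share of test days classified correctly
accuracy <- map_dbl(k_values, function(k) {
  pred <- knn(train_Market, test_Market, train_Direction, k = k)
  mean(pred == test_Direction)
})
names(accuracy) <- k_values
accuracy
```

This gives a named numeric vector rather than a data frame, which is fine for a quick look but does not keep the individual predictions around.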

Thanks for keeping me honest, Alex. – AmeliaMN

Do you want to stick with dplyr/do? This seems well suited to list loops à la lapply or purrr's functions. – aosmith

Sorry, Amelia. I was just skimming the linked text and lost my focus. –

Answer

Since each of your KNN (and oracle) outputs is a vector, this is a good case for tidyr's unnest (combined with purrr's map and rep_along):

library(class) 
library(purrr) 
library(tidyr) 
set.seed(1) 

predictions <- data_frame(k = 1:5) %>% 
    unnest(prediction = map(k, ~ knn(train_Market, test_Market, train_Direction, k = .))) %>% 
    mutate(oracle = rep_along(prediction, test_Direction)) 

The predictions variable is then organized as:

# A tibble: 1,260 x 3
       k prediction oracle
   <int>     <fctr> <fctr>
1      1         Up     Up
2      1       Down     Up
3      1         Up   Down
4      1         Up     Up
5      1         Up     Up
6      1       Down     Up
7      1       Down   Down
8      1       Down     Up
9      1       Down     Up
10     1         Up     Up
# ... with 1,250 more rows
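data_frame() and the name = value form of unnest() used above come from the 2016 versions of tibble and tidyr; data_frame() has since been deprecated in favour of tibble(), and since tidyr 1.0 unnest() takes the names of existing list-columns instead. A sketch of the same predictions tibble under those newer versions (an assumption about what is installed), which also replaces rep_along() with a list-column holding the true labels:

```r
library(ISLR)
library(dplyr)
library(tibble)  # tibble()
library(tidyr)   # unnest()
library(purrr)   # map()
library(class)   # knn()

train <- filter(Smarket, Year < 2005)
test  <- filter(Smarket, Year >= 2005)
train_Market    <- select(train, Lag1, Lag2)
test_Market     <- select(test, Lag1, Lag2)
train_Direction <- train$Direction
test_Direction  <- test$Direction

set.seed(1)
predictions <- tibble(k = 1:5) %>%
  mutate(
    # list-column: one factor of test-set predictions per value of k
    prediction = map(k, ~ knn(train_Market, test_Market, train_Direction, k = .x)),
    # the true labels, recycled to every row so both columns unnest in step
    oracle = list(test_Direction)
  ) %>%
  unnest(c(prediction, oracle))  # one row per (k, test day)

predictions
```

Because prediction and oracle have the same length in every row, unnesting them together keeps each prediction aligned with its true label.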

It can then easily be summarized:

predictions %>% 
    group_by(k) %>% 
    summarize(accuracy = mean(prediction == oracle)) 
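When each element of the list-column is a fitted model rather than a factor of class labels, broom turns it back into data frames that fit the same list-column-then-unnest pattern. A sketch under stated assumptions: broom, tidyr ≥ 1.0 and purrr are installed, logistic regression via glm() stands in for the model, and the two candidate formulas are hypothetical choices made only for illustration:

```r
library(ISLR)
library(dplyr)
library(tibble)
library(tidyr)
library(purrr)
library(broom)  # tidy(), augment()

train <- filter(Smarket, Year < 2005)
test  <- filter(Smarket, Year >= 2005)

# Hypothetical candidate models, one per row
models <- tibble(formula = c("Direction ~ Lag1", "Direction ~ Lag1 + Lag2")) %>%
  mutate(
    fit = map(formula, ~ glm(as.formula(.x), data = train, family = binomial)),
    # tidy(): one row per coefficient
    coefs = map(fit, tidy),
    # augment() with newdata: one row per test day plus a .fitted column;
    # type.predict = "response" puts .fitted on the probability scale
    preds = map(fit, ~ augment(.x, newdata = test, type.predict = "response"))
  )

# Coefficients of every model in one data frame
model_coefs <- models %>%
  select(formula, coefs) %>%
  unnest(coefs)

# Test-set accuracy per model; glm() models P(Direction == "Up")
# because "Up" is the second level of Direction
model_accuracy <- models %>%
  select(formula, preds) %>%
  unnest(preds) %>%
  mutate(prediction = if_else(.fitted > 0.5, "Up", "Down")) %>%
  group_by(formula) %>%
  summarize(accuracy = mean(prediction == Direction))

model_coefs
model_accuracy
```

The 0.5 cut-off is likewise an illustrative choice, not something the lab prescribes.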

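Again, since each output is a factor, you don't need broom here, but if each output were a model you could use broom's tidy or augment and then unnest in a similar way.

An important aspect of this approach is that it extends to many combinations of parameters: combine it with tidyr's crossing (or expand.grid) and use invoke_rows to apply the function to each row. For example, you could try variations of l alongside k: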

crossing(k = 2:5, l = 0:1) %>% 
    invoke_rows(knn, ., train = train_Market, test = test_Market, cl = train_Direction) %>% 
    unnest(prediction = .out) %>% 
    mutate(oracle = rep_along(prediction, test_Direction)) %>% 
    group_by(k, l) %>% 
    summarize(accuracy = mean(prediction == oracle)) 

This returns:

Source: local data frame [8 x 3]
Groups: k [?]

      k     l  accuracy
  <int> <int>     <dbl>
1     2     0 0.5396825
2     2     1 0.5277778
3     3     0 0.5317460
4     3     1 0.5317460
5     4     0 0.5277778
6     4     1 0.5357143
7     5     0 0.4841270
8     5     1 0.4841270
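invoke_rows() has since been moved out of purrr into the separate purrrlyr package, so the pipeline above may not run on a current installation. A sketch of the same grid search with purrr::pmap() instead, assuming tidyr ≥ 1.0 and dplyr ≥ 1.0 (for unnest() with column names and summarize(.groups = )). In knn(), l is the minimum vote needed for a definite decision. The sketch reseeds with set.seed(1), so its exact accuracies need not match the table above:

```r
library(ISLR)
library(dplyr)
library(tidyr)   # crossing(), unnest()
library(purrr)   # pmap()
library(class)   # knn()

train <- filter(Smarket, Year < 2005)
test  <- filter(Smarket, Year >= 2005)
train_Market    <- select(train, Lag1, Lag2)
test_Market     <- select(test, Lag1, Lag2)
train_Direction <- train$Direction
test_Direction  <- test$Direction

set.seed(1)
grid_accuracy <- crossing(k = 2:5, l = 0:1) %>%  # every (k, l) pair
  mutate(
    # pmap() calls the function once per row, matching k and l by name
    prediction = pmap(list(k = k, l = l), function(k, l) {
      knn(train_Market, test_Market, train_Direction, k = k, l = l)
    }),
    oracle = list(test_Direction)
  ) %>%
  unnest(c(prediction, oracle)) %>%
  group_by(k, l) %>%
  summarize(accuracy = mean(prediction == oracle), .groups = "drop")

grid_accuracy
```

Adding another tuning parameter only means adding a column to crossing() and a matching argument to the function passed to pmap().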