2017-01-01 49 views
2

我已經使用data.table實現了here所描述的一個簡單的動態編程示例,希望它能像矢量化代碼一樣快。使用循環順序更新data.table列

library(data.table) 
B=100; M=50; alpha=0.5; beta=0.9; 
n = B + M + 1 
m = M + 1 
u <- function(c)c^alpha 
dt <- data.table(s = 0:(B+M))[, .(a = 0:min(s, M)), s] # State Space and corresponging Action Space 
dt[, u := (s-a)^alpha,]        # rewards r(s, a) 
dt <- dt[, .(s_next = a:(a+B), u = u), .(s, a)]  # all possible (s') for each (s, a) 
dt[, p := 1/(B+1), s]         # transition probs 

#   s a s_next u  p 
#  1: 0 0  0 0 0.009901 
#  2: 0 0  1 0 0.009901 
#  3: 0 0  2 0 0.009901 
#  4: 0 0  3 0 0.009901 
#  5: 0 0  4 0 0.009901 
# ---       
#649022: 150 50 146 10 0.009901 
#649023: 150 50 147 10 0.009901 
#649024: 150 50 148 10 0.009901 
#649025: 150 50 149 10 0.009901 
#649026: 150 50 150 10 0.009901 

給一點內容,我的問題:在sass_next)未來值的條件被實現爲a:(a+10)一個,每個概率p=1/(B + 1)u列給出了每種組合(s, a)u(s, a)

  • 給定初始值V(總是n by 1向量)爲每個唯一的狀態sV根據V[s]=max(u(s, a)) + beta* sum(p*v(s_next))(貝爾曼方程)的更新。
  • 最大化爲a,因此,[, `:=`(v = max(v), i = s_next[which.max(v)]), by = .(s)]在下面的迭代中。

實際上有非常高效的vectorized solution。我認爲data.table解決方案的性能與矢量化方法相當。

我知道主要罪魁禍首是dt[, v := V[s_next + 1]]。唉,我不知道如何解決它。

# Iteration starts here 
system.time({ 
    V <- rep(0, n) # initial guess for Value function 
    i <- 1 
    tol <- 1 
    while(tol > 0.0001){ 
    dt[, v := V[s_next + 1]] 
    dt[, v := u + beta * sum(p*v), by = .(s, a) 
     ][, `:=`(v = max(v), i = s_next[which.max(v)]), by = .(s)] # Iteration 
    dt1 <- dt[, .(v[1L], i[1L]), by = s] 
    Vnew <- dt1$V1   
    sig <- dt1$V2 
    tol <- max(abs(V - Vnew)) 
    V <- Vnew 
    i <- i + 1 
    }  
}) 
# user system elapsed 
# 5.81 0.40 6.25 

令我沮喪的是,data.table解決方案甚至比以下非矢量化解決方案還要慢。作爲一個馬虎data.table用戶,我必須錯過data.table功能。有沒有辦法改善事情,或者,data.table不適合這種計算?

S <- 0:(n-1) # StateSpace 
VFI <- function(V){ 
    out <- rep(0, length(V)) 
    for(s in S){ 
    x <- -Inf 
    for(a in 0:min(s, M)){ 
     s_next <- a:(a+B)  # (s') 
     x <- max(x, u(s-a) + beta * sum(V[s_next + 1]/(B+1))) 
    } 
    out[s+1] <- x 
    } 
    out 
} 
system.time({ 
V <- rep(0, n) # initial guess for Value function 
i <- 1 
tol <- 1 
while(tol > 0.0001){ 
    Vnew <- VFI(V)   
    tol <- max(abs(V - Vnew)) 
    V <- Vnew 
    i <- i + 1 
}  
}) 
# user system elapsed 
# 3.81 0.00 3.81 
+2

請參閱https://stackoverflow.com/questions/5963269/how-to-make-a-great-r-reproducible-example。有人可能會抽出時間來解決這個問題,但減少到最簡單的問題演示(在你的情況下,使用data.table緩慢)將會得到更好的結果。 –

+8

@JackWasey你真的有一些神經。你真的認爲這個鏈接是需要的嗎?我認爲Khashaa知道r/data.table沒有比你更糟,並知道如何創建MWE。如果你不能幫助,你可以繼續前進 - 不需要自負的評論。 –

+4

如果你的問題的主要目標是如何提高data.table方法的性能,那麼也許別人可以提供幫助。但是,如果你只是在尋找改善這些動態模型性能的方法,那麼我個人總是使用RCpp來處理這種事情。向量化動態模型通常很棘手,而且通常不可能。如果需要速度,RCpp通常是最好的選擇。 – dww

回答

2

這是我會怎麼做這個...

DT = CJ(s = seq_len(n)-1L, a = seq_len(m)-1L, s_next = seq_len(n)-1L) 
DT[ , p := 0] 
#p is 0 unless this is true 
DT[between(s_next, a, a + B), p := 1/(B+1)] 
#may as well subset to eliminate irrelevant states 
DT = DT[p>0 & s>=a] 
DT[ , util := u(s - a)] 

#don't technically need by, but just to be careful 
DT[ , V0 := rep(0, n), by = .(a, s_next)] 

while(TRUE) { 
    #for each s, maximize given past value; 
    # within each s, have to sum over s_nexts, 
    # to do so, sum by a 
    DT[ , V1 := max(.SD[ , util[1L] + beta*sum(V0*p), by = a], 
       na.rm = TRUE), by = s] 
    if (DT[ , max(abs(V0 - V1))] < 1e-4) break 
    DT[ , V0 := V1] 
} 

在我的機器大約需要15秒(所以不太好)......但也許這會給你一些想法。例如,此data.table太大,因爲最終只有n的唯一值V