0

在Double DQN(在CNTK中實現)中,我試圖使用在線模型計算下一個狀態的值(post_state_var)。爲了矢量化我的解決方案,我使用了one_hot op。但是,當我嘗試訓練時,出現以下錯誤:從反向傳播中排除OneHot op

節點「OneHot」可用於訓練,但不參與梯度傳播。

我定義我的模型和投入:

state_var = cntk.input_variable(state_shape, name='state') 
action_var = cntk.input_variable(1, name='action') 
reward_var = cntk.input_variable(1, name='reward') 
post_state_var = cntk.input_variable(state_shape, name='post_state') 
terminal_var = cntk.input_variable(1, name='terminal') 

with cntk.default_options(activation=relu): 
    model_fn = Sequential([ 
     Dense(32, name='h1'), 
     Dense(32, name='h2'), 
     Dense(action_shape, name='action') 
    ]) 

model = model_fn(state_var) 
target_model = model.clone(cntk.CloneMethod.freeze) 

我再計算出目標值和定義損失如下:

# Value of action selected at state t 
state_value = cntk.reduce_sum(model * one_hot(action_var, num_classes=action_shape), axis=1) 

# Double Q learning - Value of action selected at state t+1 
online_post_state_model = model_fn(post_state_var) 
online_post_state_best_action = cntk.argmax(online_post_state_model) 
post_state_best_value = cntk.reduce_sum(target_model * 
             one_hot(online_post_state_best_action, num_classes=action_shape)) 

gamma = 0.99 
target = reward_var + (1.0 - terminal_var) * gamma * post_state_best_value 

# MSE for simplicity 
td_error = state_value - cntk.stop_gradient(target) 
loss = cntk.reduce_mean(cntk.square(td_error)) 

如果我更換

online_post_state_model = model_fn(post_state_var) 

online_post_state_model = model_fn.clone(cntk.CloneMethod.freeze)(post_state_var) 

然後錯誤消失了,但這是錯誤的,因爲它使用了一箇舊的凍結模型來計算目標。如何使用post_state_var評估model_fn並排除反向傳播的輸出?我沒有正確使用stop_gradient嗎?

回答

0

one_hot的典型用途是用於您通常不需要反向傳播的輸入數據。

解決方法是將動作保持爲圖中一個熱點向量。您可以使用hardmax而不是argmax來實現。