In Double DQN (implemented in CNTK), I am trying to compute the value of the next state (post_state_var) using the online model. To vectorize my solution, I used the one_hot op. However, when I try to train, I get the following error:

Excluding OneHot op from backpropagation
Node 'OneHot' can be used in training, but it does not participate in gradient propagation.

I define my model and inputs as follows:
import cntk
from cntk.ops import one_hot, relu
from cntk.layers import Dense, Sequential

state_var = cntk.input_variable(state_shape, name='state')
action_var = cntk.input_variable(1, name='action')
reward_var = cntk.input_variable(1, name='reward')
post_state_var = cntk.input_variable(state_shape, name='post_state')
terminal_var = cntk.input_variable(1, name='terminal')

with cntk.default_options(activation=relu):
    model_fn = Sequential([
        Dense(32, name='h1'),
        Dense(32, name='h2'),
        Dense(action_shape, name='action')
    ])

model = model_fn(state_var)
target_model = model.clone(cntk.CloneMethod.freeze)
I then compute the target value and define the loss as follows:
# Value of action selected at state t
state_value = cntk.reduce_sum(model * one_hot(action_var, num_classes=action_shape), axis=1)

# Double Q learning - Value of action selected at state t+1
online_post_state_model = model_fn(post_state_var)
online_post_state_best_action = cntk.argmax(online_post_state_model)
post_state_best_value = cntk.reduce_sum(target_model *
                                        one_hot(online_post_state_best_action, num_classes=action_shape))

gamma = 0.99
target = reward_var + (1.0 - terminal_var) * gamma * post_state_best_value

# MSE for simplicity
td_error = state_value - cntk.stop_gradient(target)
loss = cntk.reduce_mean(cntk.square(td_error))
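For clarity, this is the Double Q-learning target the graph above is meant to compute in a vectorized way, written out for a single transition with NumPy (the Q-value rows below are purely illustrative numbers):

import numpy as np

# Illustrative Q-value rows for one transition at state s_{t+1}, shape (action_shape,)
q_online_next = np.array([0.1, 0.5, 0.2])   # online model evaluated at post_state
q_target_next = np.array([0.3, 0.4, 0.6])   # frozen target model evaluated at post_state

best_action = np.argmax(q_online_next)      # action selection uses the online model
gamma, reward, terminal = 0.99, 1.0, 0.0
# Action evaluation uses the target model: r + gamma * Q_target(s', argmax_a Q_online(s', a))
target = reward + (1.0 - terminal) * gamma * q_target_next[best_action]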
If I replace

online_post_state_model = model_fn(post_state_var)

with

online_post_state_model = model_fn.clone(cntk.CloneMethod.freeze)(post_state_var)

then the error goes away, but this is wrong because it uses an old, frozen copy of the model to compute the target. How can I evaluate model_fn on post_state_var while excluding its output from backpropagation? Am I not using stop_gradient correctly?
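For reference, here is a minimal sketch of the alternative I would expect to work, assuming that clone with CloneMethod.share plus an input substitution reuses the online parameters instead of freezing a snapshot (I have not confirmed whether this avoids the OneHot message):

# Sketch only: evaluate the online model on post_state_var by substituting
# the input variable, so parameters stay shared with the online network.
online_post_state_model = model.clone(cntk.CloneMethod.share,
                                      {state_var: post_state_var})
online_post_state_best_action = cntk.argmax(online_post_state_model)
post_state_best_value = cntk.reduce_sum(target_model *
                                        one_hot(online_post_state_best_action, num_classes=action_shape))

# stop_gradient around the whole target should keep gradients from flowing
# into the target computation, including the argmax/one_hot path.
target = cntk.stop_gradient(reward_var + (1.0 - terminal_var) * gamma * post_state_best_value)
td_error = state_value - target
loss = cntk.reduce_mean(cntk.square(td_error))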