我在Keras中實現了我自己的復發層,並且在step
函數內部我希望能夠訪問所有時間步驟中的隱藏狀態,而不僅僅是最後一個狀態如默認,以便我可以做一些事情,比如及時向後添加跳過連接。返回Keras中RNN中所有時間步驟的所有狀態
我試圖修改tensorflow後端K.rnn
內的_step
以返回到目前爲止所有隱藏狀態。我最初的想法是簡單地將每個隱藏狀態存儲到TensorArray中,然後將所有這些狀態都傳遞給step_function
(即我的層中的step
函數)。我現在的修改功能;下面,寫每個隱藏狀態轉變爲TensorArray states_ta_t
:
def _step(time, output_ta_t, states_ta_t, *states):
current_input = input_ta.read(time)
# Here I'd like to return all states up to current time
# and pass to step_function, instead of just the last
states = [states_ta_t.read(time)]
output, new_states = step_function(current_input,
tuple(states) +
tuple(constants))
for state, new_state in zip(states, new_states):
new_state.set_shape(state.get_shape())
states_ta_t = states_ta_t.write(time+1, new_states[0]) # record states
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t, states_ta_t) + tuple(new_states)
這個版本只返回一個狀態,就像當初的實施,並可以作爲一個正常的RNN。我如何獲取迄今爲止所有的狀態,存儲在數組中,並傳遞給step_function
?感覺這應該是非常簡單的,但我對TensorArrays不是很熟練......
(注意:在展開的版本中這比在符號上更容易,但不幸的是我會用完使用我的實驗展開的版本)
歡迎來到SO。請閱讀這個[如何回答] (http://stackoverflow.com/help/how-to-answer)提供高質量的答案。 – thewaywewere
感謝您的諮詢!我會讀它,並盡力提高我的答案:) – Carefree0910
非常感謝@ Carefree0910。這回答了我的問題,我不知道我可以用這種方式對它們進行切片:-)我最終意識到,我可能會以這種方式使用太多的內存,通過在state_ta_t中一次保持所有狀態。所以我最終創建了兩個TensorArrays,一個用於當前時間步和一個上一個時間步,用「clear_after_read = True」,這樣我只能訪問一個額外的狀態,但只保留兩個狀態隨時在內存中。 – jodles