2017-05-30 64 views
1

我在Keras中實現了我自己的復發層,並且在step函數內部我希望能夠訪問所有時間步驟中的隱藏狀態,而不僅僅是最後一個狀態如默認,以便我可以做一些事情,比如及時向後添加跳過連接。返回Keras中RNN中所有時間步驟的所有狀態

我試圖修改tensorflow後端K.rnn內的_step以返回到目前爲止所有隱藏狀態。我最初的想法是簡單地將每個隱藏狀態存儲到TensorArray中,然後將所有這些狀態都傳遞給step_function(即我的層中的step函數)。我現在的修改功能;下面,寫每個隱藏狀態轉變爲TensorArray states_ta_t

def _step(time, output_ta_t, states_ta_t, *states): 
      current_input = input_ta.read(time) 
      # Here I'd like to return all states up to current time 
      # and pass to step_function, instead of just the last 
      states = [states_ta_t.read(time)] 
      output, new_states = step_function(current_input, 
               tuple(states) + 
               tuple(constants)) 
      for state, new_state in zip(states, new_states): 
       new_state.set_shape(state.get_shape()) 
      states_ta_t = states_ta_t.write(time+1, new_states[0]) # record states 
      output_ta_t = output_ta_t.write(time, output) 
      return (time + 1, output_ta_t, states_ta_t) + tuple(new_states) 

這個版本只返回一個狀態,就像當初的實施,並可以作爲一個正常的RNN。我如何獲取迄今爲止所有的狀態,存儲在數組中,並傳遞給step_function?感覺這應該是非常簡單的,但我對TensorArrays不是很熟練......

(注意:在展開的版本中這比在符號上更容易,但不幸的是我會用完使用我的實驗展開的版本)

回答

2

記憶 - 編輯 -

我發現,我誤解你的問題,我非常抱歉爲...

總之,試試這個:

states = states_ta_t.stack()[:time] 

下面是一些說明:你確實已經將所有這些狀態存儲在states_ta_t中,但你只能通過最後一個到step_function

你已經在你的代碼做的是:

# Param 'time' refers to 'current time step' 
states = [states_ta_t.read(time)] 

這意味着,你從states_ta_t讀取「當前」的狀態,換句話說,最後的狀態。

如果你想做一些切片,也許stack功能將有所幫助。例如:

states = states_ta_t.stack()[:time] 

但我不知道這是否是一個正確的實施,因爲我不熟悉TensorArray要麼...

希望它能幫助!如果沒有,如果你願意留下評論並與我討論,這是我的榮幸!

+1

歡迎來到SO。請閱讀這個[如何回答] (http://stackoverflow.com/help/how-to-answer)提供高質量的答案。 – thewaywewere

+1

感謝您的諮詢!我會讀它,並盡力提高我的答案:) – Carefree0910

+1

非常感謝@ Carefree0910。這回答了我的問題,我不知道我可以用這種方式對它們進行切片:-)我最終意識到,我可能會以這種方式使用太多的內存,通過在state_ta_t中一次保持所有狀態。所以我最終創建了兩個TensorArrays,一個用於當前時間步和一個上一個時間步,用「clear_after_read = True」,這樣我只能訪問一個額外的狀態,但只保留兩個狀態隨時在內存中。 – jodles

相關問題