2016-09-27 83 views
10

如何從TensorFlow中的tf.nn.rnn()tf.nn.dynamic_rnn()獲取所有隱藏狀態? API只給我最後的狀態。TensorFlow:從RNN獲取所有狀態

第一種方法是在構建直接在RNNCell上運行的模型時編寫循環。但是,時間步數對我來說並不固定,並且取決於傳入的批次。

一些選項是使用GRU或編寫我自己的將狀態連接到輸出的RNNCell。前者的選擇不夠普遍,後者聽起來太冒險。

另一種選擇是做類似the answers in this question的事情,從RNN獲取所有變量。但是,我不確定如何以標準方式將隱藏狀態與其他變量分開。

在仍使用庫提供的RNN API時,是否有一種很好的方法可以從RNN獲取所有隱藏狀態?

+0

我創建了一個PR [這裏](https://github.com/tensorflow/tensorflow/pull/9995),它可能會幫助你處理簡單的案例 – Carefree0910

回答

0

我已經創建了一個PR here並且它可以幫助你處理簡單的案件

讓我簡單介紹一下我的實現,所以你可以,如果你需要編寫自己的版本。主要的部分是_time_step功能的修改:

def _time_step(time, output_ta_t, state, *args): 

的參數保持不變,除了額外的*args傳入但是爲什麼args?這是因爲我想支持張量流傳統的習慣行爲。您可以只通過簡單地忽略args參數返回最終狀態:

if states_ta is not None: 
    # If you want to return all states, set `args` to be `states_ta` 
    loop_vars = (time, output_ta, state, states_ta) 
else: 
    # If you want the final state only, ignore `args` 
    loop_vars = (time, output_ta, state) 

如何使用它?

if args: 
    args = tuple(
     ta.write(time, out) for ta, out in zip(args[0], [new_state]) 
    ) 

其實這只是一個以下(原件)代碼修改:

output_ta_t = tuple(
    ta.write(time, out) for ta, out in zip(output_ta_t, output) 
) 

現在args應該包含所有你想要的狀態。

之後所有的作品上面做,你可以用以下代碼拿起狀態(或最終狀態):

_, output_final_ta, *state_info = control_flow_ops.while_loop(... 

if states_ta is not None: 
    final_state, states_final_ta = state_info 
else: 
    final_state, states_final_ta = state_info[0], None 

雖然我還沒有測試它它應該在'簡單'條件下工作(here's我的測試用例)

+0

感謝您花時間撰寫答案。在回答你的第一句話時,最好不要在Stack Overflow上有重複的信息。一旦你獲得了75的聲望,你就可以將一個問題標記爲另一個問題的複製品(雖然我可能是錯的,也許你現在可以這樣做)。如果問題不一樣,最好定製每個答案以適應問題的需求。 – zondo

+0

感謝您的評論!我已經發現這兩個問題之間的一些差異所以現在我也跟着你的建議和定製的每個答案:) – Carefree0910

+0

我其實解決了,這是通過創建一個封裝器單元(如MultiRNNCell),其輸出與輸出連接在一起的狀態的方式。在此之後,只需要進行拆分即可將輸出與隱藏狀態分開。 –

2

tf.nn.dynamic_rnn(也有tf.nn.static_rnn)有兩個返回值;如您所說,「狀態」是RNN的最終狀態,但「輸出」是RNN的所有隱藏狀態(該形狀是[batch_size,max_time,cell.output_size ])

您可以使用「outputs」作爲RNN的隱藏狀態,因爲在大多數庫提供的RNNCell中,「output」和「state」是相同的。(除LSTMCell)

+0

撇開這是特定於GRU,這並不能幫助你,如果你有多個層,例如,如果你換行[GRUCell](https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/GRUCell)在[MultiRNNCell](https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/MultiRNNCell)中。你的輸出將只包含來自最後一層的狀態。 –