2017-06-22 49 views
2

我的理解是,tf.nn.dynamic_rnn返回在每個時間步的RNN單元(例如LSTM)的輸出以及最終狀態。我如何在所有時間步驟中訪問單元狀態,而不僅僅是最後一步?例如,我希望能夠平均所有隱藏狀態,然後在後續圖層中使用它。Tensorflow,如何訪問一個RNN的所有中間狀態,而不僅僅是最後一個狀態

以下是我如何定義LSTM單元,然後使用tf.nn.dynamic_rnn展開它。但是這隻給出了LSTM的最後一個單元狀態。

import tensorflow as tf 
import numpy as np 

# [batch-size, sequence-length, dimensions] 
X = np.random.randn(2, 10, 8) 
X[1,6:] = 0 
X_lengths = [10, 6] 

cell = tf.contrib.rnn.LSTMCell(num_units=64, state_is_tuple=True) 

outputs, last_state = tf.nn.dynamic_rnn(
    cell=cell, 
    dtype=tf.float64, 
    sequence_length=X_lengths, 
    inputs=X) 
sess = tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())         
out, last = sess.run([outputs, last_state], feed_dict=None) 
+0

沒有理由需要訪問不屬於輸出結果的內部狀態。如果這是您的使用案例,我會考慮定義一個與LSTM相同的RNN,但輸出其完整狀態。 – jasekp

+0

看看這個QA:https://stackoverflow.com/q/39716241/4282745 – npf

+0

或這個https://github.com/tensorflow/tensorflow/issues/5731#issuecomment-262151359 – npf

回答

1

這樣的事情應該工作。

import tensorflow as tf 
import numpy as np 


class CustomRNN(tf.contrib.rnn.LSTMCell): 
    def __init__(self, *args, **kwargs): 
     kwargs['state_is_tuple'] = False # force the use of a concatenated state. 
     returns = super(CustomRNN, self).__init__(*args, **kwargs) # create an lstm cell 
     self._output_size = self._state_size # change the output size to the state size 
     return returns 
    def __call__(self, inputs, state): 
     output, next_state = super(CustomRNN, self).__call__(inputs, state) 
     return next_state, next_state # return two copies of the state, instead of the output and the state 

X = np.random.randn(2, 10, 8) 
X[1,6:] = 0 
X_lengths = [10, 10] 

cell = CustomRNN(num_units=64) 

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell, 
    dtype=tf.float64, 
    sequence_length=X_lengths, 
    inputs=X) 

sess = tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())         
states, last_state = sess.run([outputs, last_states], feed_dict=None) 

這使用連接狀態,因爲我不知道是否可以存儲任意數量的元組狀態。狀態變量具有形狀(batch_size,max_time_size,state_size)。

+0

您可否詳細介紹一下CustomRNN代碼如何返回中間狀態?我正試圖理解你的代碼! – CentAu

+1

LSTM狀態是輸出(m)和隱藏狀態(c)的組合。這段代碼取出輸出(m)並用連接狀態(c + m)替換它。忽略批量大小,輸出是[(c1 + m1),(c2 + m2),...]的列表,而不是[m1,m2,...]。 – jasekp

+0

因此,這會用隱藏狀態(c)替換實際的輸出(m),正確的('return next_state,next_state',而不是'return m,new_state')?你在哪裏連接輸出和隱藏狀態('m + c')? – CentAu

1

我想指出你這個thread(從我的亮點):

你可以寫一個返回兩個狀態張量作爲輸出的一部分LSTMCell的變體,如果你既需要Ç和每個時間步的h狀態。如果您只需要h狀態,那就是每個時間步的輸出。

由於@jasekp在其評論中寫道,輸出結果真的是h部分狀態。然後dynamic_rnn方法只會疊加無論在時間的h部分(見_dynamic_rnn_loopthis file字符串DOC):

def _dynamic_rnn_loop(cell, 
         inputs, 
         initial_state, 
         parallel_iterations, 
         swap_memory, 
         sequence_length=None, 
         dtype=None): 
    """Internal implementation of Dynamic RNN. 
    [...] 
    Returns: 
    Tuple `(final_outputs, final_state)`. 
    final_outputs: 
     A `Tensor` of shape `[time, batch_size, cell.output_size]`. If 
     `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape` 
     objects, then this returns a (possibly nsted) tuple of Tensors matching 
     the corresponding shapes. 
+0

LSTMCell只是一個單元如果我沒有弄錯,則返回狀態和輸出。我認爲'tf.nn.dynamic_rnn'的展開部分只返回最後一步。所以,我需要修改它?奇怪的是,還沒有一個更高層次的解決方案。 – CentAu

相關問題