2016-06-26 71 views
4

我想用一個attention_decoder構建一個seq2seq模型,並使用帶LSTMCell的MultiRNNCell作爲編碼器。因爲TensorFlow代碼建議「這種默認行爲(state_is_tuple = False)很快就會被棄用。」,我爲編碼器設置了state_is_tuple = True。TensorFlow attention_decoder with RNNCell(state_is_tuple = True)

的問題是,當我通過編碼器attention_decoder的狀態,它會報告錯誤:

*** AttributeError: 'LSTMStateTuple' object has no attribute 'get_shape' 

這個問題似乎在seq2seq.py和_linear要進行相關的注意()函數()函數在rnn_cell.py中,代碼從編碼器生成的initial_state調用'LSTMStateTuple'對象的'get_shape()'函數。

雖然當我設置state_is_tuple = False的編碼錯誤消失,程序給出以下警告:

WARNING:tensorflow:<tensorflow.python.ops.rnn_cell.LSTMCell object at 0x11763dc50>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True. 

我會很感激,如果有人可以提供關於建立seq2seq與RNNCell任何指令(state_is_tuple =真正)。

回答

0

我也遇到過這個問題,lstm狀態需要連接,否則_linear會抱怨。 LSTMStateTuple的形狀取決於您使用的單元種類。隨着LSTM單元,您可以連接狀態是這樣的:

query = tf.concat(1,[state[0], state[1]]) 

如果您使用的是MultiRNNCell,串聯狀態每層第一:

concat_layers = [tf.concat(1,[c,h]) for c,h in state] 
query = tf.concat(1, concat_layers) 
相關問題