我正在構建用於語言識別的statefull LSTM。 正在有條件的我可以用更小的文件來訓練網絡,並且新的批處理將會像討論中的下一句話一樣。 但是,要正確訓練網絡,我需要重置一些批次之間的LSTM的隱藏狀態。Tensorflow RNN-LSTM - 重置隱藏狀態
我使用一個變量來存儲LSTM的hidden_state性能:
with tf.variable_scope('Hidden_state'):
hidden_state = tf.get_variable("hidden_state", [self.num_layers, 2, self.batch_size, self.hidden_size],
tf.float32, initializer=tf.constant_initializer(0.0), trainable=False)
# Arrange it to a tuple of LSTMStateTuple as needed
l = tf.unstack(hidden_state, axis=0)
rnn_tuple_state = tuple([tf.contrib.rnn.LSTMStateTuple(l[idx][0], l[idx][1])
for idx in range(self.num_layers)])
# Build the RNN
with tf.name_scope('LSTM'):
rnn_output, _ = tf.nn.dynamic_rnn(cell, rnn_inputs, sequence_length=input_seq_lengths,
initial_state=rnn_tuple_state, time_major=True)
現在我對如何重置隱藏狀態混亂。我已經嘗試了兩種解決方案,但它不工作:
首先解決
重置與「hidden_state」變量:
rnn_state_zero_op = hidden_state.assign(tf.zeros_like(hidden_state))
它不工作,我想這是因爲拆散和元組在運行rnn_state_zero_op操作後,構造不會「重新播放」到圖中。
解決方法二
繼LSTMStateTuple vs cell.zero_state() for RNN in Tensorflow我試着細胞狀態重置:
rnn_state_zero_op = cell.zero_state(self.batch_size, tf.float32)
它似乎沒有任何工作。
問題
我在心中另一種解決方案,但它在猜測最好的:我沒有保持由tf.nn.dynamic_rnn返回的狀態,我已經想到這一點,但我得到一個元組我無法找到一種方法來構建重置元組的操作。
在這一點上,我必須承認,我不太瞭解tensorflow的內部工作,如果甚至有可能做我想做的事情。 有沒有適當的方法來做到這一點?
謝謝!