2017-08-01 141 views
1

我正在構建用於語言識別的statefull LSTM。 正在有條件的我可以用更小的文件來訓練網絡,並且新的批處理將會像討論中的下一句話一樣。 但是,要正確訓練網絡,我需要重置一些批次之間的LSTM的隱藏狀態。Tensorflow RNN-LSTM - 重置隱藏狀態

我使用一個變量來存儲LSTM的hidden_​​state性能:

with tf.variable_scope('Hidden_state'): 
     hidden_state = tf.get_variable("hidden_state", [self.num_layers, 2, self.batch_size, self.hidden_size], 
             tf.float32, initializer=tf.constant_initializer(0.0), trainable=False) 
     # Arrange it to a tuple of LSTMStateTuple as needed 
     l = tf.unstack(hidden_state, axis=0) 
     rnn_tuple_state = tuple([tf.contrib.rnn.LSTMStateTuple(l[idx][0], l[idx][1]) 
           for idx in range(self.num_layers)]) 

    # Build the RNN 
    with tf.name_scope('LSTM'): 
     rnn_output, _ = tf.nn.dynamic_rnn(cell, rnn_inputs, sequence_length=input_seq_lengths, 
              initial_state=rnn_tuple_state, time_major=True) 

現在我對如何重置隱藏狀態混亂。我已經嘗試了兩種解決方案,但它不工作:

首先解決

重置與「hidden_​​state」變量:

rnn_state_zero_op = hidden_state.assign(tf.zeros_like(hidden_state)) 

它不工作,我想這是因爲拆散和元組在運行rnn_state_zero_op操作後,構造不會「重新播放」到圖中。

解決方法二

LSTMStateTuple vs cell.zero_state() for RNN in Tensorflow我試着細胞狀態重置:

rnn_state_zero_op = cell.zero_state(self.batch_size, tf.float32) 

它似乎沒有任何工作。

問題

我在心中另一種解決方案,但它在猜測最好的:我沒有保持由tf.nn.dynamic_rnn返回的狀態,我已經想到這一點,但我得到一個元組我無法找到一種方法來構建重置元組的操作。

在這一點上,我必須承認,我不太瞭解tensorflow的內部工作,如果甚至有可能做我想做的事情。 有沒有適當的方法來做到這一點?

謝謝!

回答

2

感謝this answer to another question我能找到一種方法,對RNN的內部狀態是否(以及何時)應該被重置爲0

首先完全控制你需要定義一些變量存儲RNN的狀態,這樣你就擁有控制權:

with tf.variable_scope('Hidden_state'): 
    state_variables = [] 
    for state_c, state_h in cell.zero_state(self.batch_size, tf.float32): 
     state_variables.append(tf.nn.rnn_cell.LSTMStateTuple(
      tf.Variable(state_c, trainable=False), 
      tf.Variable(state_h, trainable=False))) 
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state 
    rnn_tuple_state = tuple(state_variables) 

請注意,此版本直接定義由LSTM使用的變量,因爲你不這是我的問題比版本好得多」不得不拆開並構建元組,這會爲圖形添加一些操作,使其無法顯式運行。

其次建立RNN和檢索最終狀態:所以現在你有RNN的新的內部狀態

# Build the RNN 
with tf.name_scope('LSTM'): 
    rnn_output, new_states = tf.nn.dynamic_rnn(cell, rnn_inputs, 
               sequence_length=input_seq_lengths, 
               initial_state=rnn_tuple_state, 
               time_major=True) 

。您可以定義兩個操作來管理它。

第一個將更新下一批次的變量。

# Define an op to keep the hidden state between batches 
update_ops = [] 
for state_variable, new_state in zip(rnn_tuple_state, new_states): 
    # Assign the new state to the state variables on this layer 
    update_ops.extend([state_variable[0].assign(new_state[0]), 
         state_variable[1].assign(new_state[1])]) 
# Return a tuple in order to combine all update_ops into a single operation. 
# The tuple's actual value should not be used. 
rnn_keep_state_op = tf.tuple(update_ops) 

你應該要運行一個批處理這個運算添加到您的會話隨時隨地保持內部:那麼在下一批次的「initial_state」的RNN將與前一批次的最終狀態下供給州。

請注意:如果您運行帶有此op的批處理1,則批2將以批1的最終狀態開始,但如果在運行批2時不再調用它,則批3將以批1開始最終狀態也。我的建議是每次運行RNN時添加此操作。

第二OP將用於該RNN的內部狀態重置爲零:

# Define an op to reset the hidden state to zeros 
update_ops = [] 
for state_variable in rnn_tuple_state: 
    # Assign the new state to the state variables on this layer 
    update_ops.extend([state_variable[0].assign(tf.zeros_like(state_variable[0])), 
         state_variable[1].assign(tf.zeros_like(state_variable[1]))]) 
# Return a tuple in order to combine all update_ops into a single operation. 
# The tuple's actual value should not be used. 
rnn_state_zero_op = tf.tuple(update_ops) 

,每當你想重置內部狀態,您可以調用這個運算。

0

簡化AMairesse後的版本爲一個LSTM層:

zero_state = tf.zeros(shape=[1, units[-1]]) 
self.c_state = tf.Variable(zero_state, trainable=False) 
self.h_state = tf.Variable(zero_state, trainable=False) 
self.init_encoder = tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state) 

self.output_encoder, self.state_encoder = tf.nn.dynamic_rnn(cell_encoder, layer, initial_state=self.init_encoder) 

# save or reset states 
self.update_ops += [self.c_state.assign(self.state_encoder.c, use_locking=True)] 
self.update_ops += [self.h_state.assign(self.state_encoder.h, use_locking=True)] 

,也可以使用替代init_encoder在步驟== 0復位狀態(需要傳遞到self.step_tf session.run()作爲佔位符):

self.step_tf = tf.placeholder_with_default(tf.constant(-1, dtype=tf.int64), shape=[], name="step") 

self.init_encoder = tf.cond(tf.equal(self.step_tf, 0), 
    true_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(zero_state, zero_state), 
    false_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state))