首先,瞭解Tensorflow內部的一切都是一個張量是很重要的。因此,當您執行某種計算時(例如,像outputs = rnn(...)
這樣的實現),此計算的輸出將作爲張量返回。所以你不需要將它存儲在任何類型的結構中。您可以通過運行通訊節點(即output
)(如session.run(output, feed_dict)
)來檢索它。
對此,我認爲你需要把RNN的最終狀態作爲後續計算的初始狀態。方法有兩種:
A)如果使用RNNCell
實現在你的建設模式,你可以構建零狀態是這樣的:
cell = (some RNNCell implementation)
initial_state = cell.zero_state(batch_size, tf.float32)
B)如果您uimplementing自己的員工定義狀態的零張量:
initial_state = tf.zeros([batch_size, hidden_size])
然後,在這兩種情況下,你會碰到這樣的:
output, final_state = rnn(input, initial_state)
在你執行循環可以先初始化狀態,然後提供您feed_dict
內final_state
爲initial_state
:
state = session.run(initial_state)
for step in range(epochs):
feed_dict = {initial_state: state}
_, state = session.run((train_op,final_state), feed_dict)
你如何真正構建您feed_dict
依賴於RNN的實施。
對於BasicLSTMCell
,例如,一個狀態是LSTMState
對象,你需要同時提供c
和h
:
feed_dict = {initial_state.c=state.c, initial_state.h: state.h}