2017-09-24 43 views
1

Tensorflow新手在這裏!我知道Variables會隨着時間的推移而被訓練,佔位符會使用輸入數據,這些輸入數據在模型訓練時不會改變(如輸入圖像和這些圖像的類標籤)。tf.zeros vs tf.placeholder作爲RNN初始狀態

我正在嘗試使用Tensorflow實現RNN的正向傳播,並且想知道我應該保存哪種類型的RNN單元的輸出。在numpy的RNN實現,它使用

hiddenStates = np.zeros((T, self.hidden_dim)) #T is the length of the sequence

然後迭代地保存np.zeros陣列中的輸出。

在TF的情況下,我應該使用tf.zeros還是tf.placeholder?

這種情況下的最佳做法是什麼?我認爲使用tf.zeros應該沒問題,但是要仔細檢查。

回答

2

首先,瞭解Tensorflow內部的一切都是一個張量是很重要的。因此,當您執行某種計算時(例如,像outputs = rnn(...)這樣的實現),此計算的輸出將作爲張量返回。所以你不需要將它存儲在任何類型的結構中。您可以通過運行通訊節點(即output)(如session.run(output, feed_dict))來檢索它。

對此,我認爲你需要把RNN的最終狀態作爲後續計算的初始狀態。方法有兩種:

A)如果使用RNNCell實現在你的建設模式,你可以構建零狀態是這樣的:

cell = (some RNNCell implementation) 
initial_state = cell.zero_state(batch_size, tf.float32) 

B)如果您uimplementing自己的員工定義狀態的零張量:

initial_state = tf.zeros([batch_size, hidden_size]) 

然後,在這兩種情況下,你會碰到這樣的:

output, final_state = rnn(input, initial_state) 

在你執行循環可以先初始化狀態,然後提供您feed_dictfinal_stateinitial_state

state = session.run(initial_state) 
for step in range(epochs): 

    feed_dict = {initial_state: state} 
    _, state = session.run((train_op,final_state), feed_dict) 

你如何真正構建您feed_dict依賴於RNN的實施。

對於BasicLSTMCell,例如,一個狀態是LSTMState對象,你需要同時提供ch

feed_dict = {initial_state.c=state.c, initial_state.h: state.h}