2016-05-03 97 views
0

我形成了一個有幾個隱藏狀態的小型LSTM單元。從Tensorflow howtos中,我可以保存和恢復用tf.Variable聲明的變量的狀態。然而,當我調查了rnn_cell.py,我看到了存在的功能:tensorflow保存庫中的共享變量

def linear(args, output_size, bias, bias_start=0.0, scope=None): 

而且裏面有一個共享變量的訪問

matrix = vs.get_variable("Matrix", [total_arg_size, output_size])

據我瞭解矩陣店權重W_i,W_o,W_f和W_o,因爲畢竟線性函數,來了:

new_c = c * sigmoid(f + self._forget_bias) + sigmoid(i) * tanh(j) 
new_h = tanh(new_c) * sigmoid(o) 

所以,我願意保存和恢復該變量 好。我的問題是這可能在哪裏?

+0

你爲什麼要這樣做?當前的保存/恢復不適用於您? –

+0

1 - 我從另一篇文章中讀到的是,保存和恢復不保存這些變量(只保存你在自己的代碼中定義的變量,也許這是錯誤的,我沒有仔細檢查)。 2-對於速度:我有數百個訓練數據文件,當我訓練這些文件時,我在文件的開頭和結尾恢復並保存狀態。因此,在批量處理單個文件的過程中,對於某些自定義計算,我需要在每批(每256個元素可以說)的Matrix值。保存和恢復每256個元素,我認爲可能是昂貴的(讀/寫io)。 – flyingmadden

回答

0

對於記錄,可以通過深入到變量範圍來獲取矩陣。 get_variable要求昏暗的信息,以及:[2 * hidden_size, 4 * hidden_size]

 with tf.variable_scope("RNN", reuse=True): 
      with tf.variable_scope("BasicLSTMCell", reuse=True): 
      with tf.variable_scope("Linear", reuse=True): 
       v1 = tf.get_variable("Matrix", [2 * hidden_size, 4 * hidden_size]) 
       print(v1.eval()) 
0

您可以通過評估它訪問您的張量。例如要得到matrix的值,你應該評估它,並通過破壞如下: ar = sess.run(matrix) for row in ar: for col in row: # your method to save your data ,你可以建立一個類,其中你的變量在這裏用作佔位符,你只需要用已經保存的已加載模型提供它們!

+0

是真的,我在會話中訪問它們 – flyingmadden