2016-02-26 160 views
0

與此相關的複製變量:How can I copy a variable in tensorflowtensorflow在RNN

我試圖複製LSTM解碼單元的值在其他地方使用它beamsearch。在僞代碼,我想是這樣的:

lstm_decode = tf.nn.rnn_cell(...) 
training_output = tf.nn.seq2seq.rnn_decoder(...) 
... do training by back-prop the error on trainint_output ... 

# duplicate the lstm_decode unit (same weights) 
lstm_decode_copy = copy(lstm_decode) 
... do beam search with the duplicated lstm ... 

的問題是,在tensorflow,沒有召喚「tf.nn.rnn_cell(......)」過程中產生的LSTM變量,但它是實際上是在函數調用展開到rnn_decoder期間生成的。

我可以將範圍設置爲「tf.nn.seq2seq.rnn_decoder」函數調用,但lstm權重的實際初始化對我來說並不透明。我如何捕獲這些值並重新使用它們來創建一個與所學的權重相同的lstm單元?

謝謝!

回答

0

我想我想通了。

你想要的是用於解碼器呼叫範圍設置爲特定值,說「解碼」,在這一行:

training_output = tf.nn.seq2seq.rnn_decoder(...scope="decoding") 

,稍後當你想用你學到的LSTM單位在解碼期間,您將變量範圍再次設置爲「解碼」,並使用scope.reuse_variables()來允許重新使用解碼的變量。然後簡單地使用「lstm_decode」,否則將被綁定到與以前相同的值。

with tf.variable_scope("decoding") as scope: 
    scope.reuse_variables() 
    ... use lstm_decode as usual ... 

這種方式在lstm_decode所有的權重將在這兩個子圖共享,並取其值學到期間的訓練將出現在第二部分也是如此。