2016-08-21 125 views
5

我想在使用minibatches的Tensorflow中訓練LSTM,但是訓練完成後我想通過一次提交一個示例來使用模型。我可以在Tensorflow中設置圖形來訓練我的LSTM網絡,但是我不能以我想要的方式使用訓練後的結果。Tensorflow:使用在另一個模型中訓練的權重,不同的模型

設置代碼看起來是這樣的:

#Build the LSTM model. 
cellRaw = rnn_cell.BasicLSTMCell(LAYER_SIZE) 
cellRaw = rnn_cell.MultiRNNCell([cellRaw] * NUM_LAYERS) 

cell = rnn_cell.DropoutWrapper(cellRaw, output_keep_prob = 0.25) 

input_data = tf.placeholder(dtype=tf.float32, shape=[SEQ_LENGTH, None, 3]) 
target_data = tf.placeholder(dtype=tf.float32, shape=[SEQ_LENGTH, None]) 
initial_state = cell.zero_state(batch_size=BATCH_SIZE, dtype=tf.float32) 

with tf.variable_scope('rnnlm'): 
    output_w = tf.get_variable("output_w", [LAYER_SIZE, 6]) 
    output_b = tf.get_variable("output_b", [6]) 

outputs, final_state = seq2seq.rnn_decoder(input_list, initial_state, cell, loop_function=None, scope='rnnlm') 
output = tf.reshape(tf.concat(1, outputs), [-1, LAYER_SIZE]) 
output = tf.nn.xw_plus_b(output, output_w, output_b) 

...注意兩個佔位符,input_data和target_data。包括優化器設置在內我沒有打擾過。訓練完成和培訓會議結束後,我想建立一個使用訓練的LSTM網絡的輸入是一個完全不同的佔位符提供了一個新的會話,是這樣的:

with tf.Session() as sess: 
with tf.variable_scope("simulation", reuse=None): 
    cellSim = cellRaw 
    input_data_sim = tf.placeholder(dtype=tf.float32, shape=[1, 1, 3]) 
    initial_state_sim = cell.zero_state(batch_size=1, dtype=tf.float32) 
    input_list_sim = tf.unpack(input_data_sim) 

    outputsSim, final_state_sim = seq2seq.rnn_decoder(input_list_sim, initial_state_sim, cellSim, loop_function=None, scope='rnnlm') 
    outputSim = tf.reshape(tf.concat(1, outputsSim), [-1, LAYER_SIZE]) 

    with tf.variable_scope('rnnlm'): 
     output_w = tf.get_variable("output_w", [LAYER_SIZE, nOut]) 
     output_b = tf.get_variable("output_b", [nOut]) 

    outputSim = tf.nn.xw_plus_b(outputSim, output_w, output_b) 

這第二部分返回以下錯誤:

tensorflow.python.framework.errors.InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float 
[[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]] 

...可能是因爲我使用的圖形仍然有舊的訓練佔位符附加到訓練的LSTM節點。 「提取」訓練的LSTM並將其放入具有不同風格輸入的新的,不同圖形的正確方法是什麼? Tensorflow似乎具有類似的東西,但示例in the documentation都提到將變量作用域用作管理變量名的一種方式,以便同一段代碼可以在同一圖中生成相似的子圖。 '重用'功能似乎與我想要的功能非常接近,但是我並沒有發現上面鏈接的Tensorflow文檔在它的功能上很清楚。自己不能給予一個名字的細胞(換句話說,

cellRaw = rnn_cell.MultiRNNCell([cellRaw] * NUM_LAYERS, name="multicell") 

是無效的),雖然我可以給一個名字到seq2seq.rnn_decoder(),我大概就無法刪除rnn_cell.DropoutWrapper()如果我使用該節點不變。

問題:

什麼是從一個圖表移動訓練有素LSTM權到另一個的正確方法?

它是正確的說,開始一個新的會話「釋放資源」,但不會刪除內置內存中的圖形?

在我看來,像「重複使用」功能允許Tensorflow爲具有相同名稱(存在於不同的範圍內)變量的當前可變範圍之外進行搜索,並在當前範圍內使用它們。它是否正確?如果是這樣,所有來自非當前範圍的鏈接到該變量的圖邊都會發生什麼?如果不是,那麼爲什麼Tensorflow會在兩個不同範圍內嘗試使用相同的變量名稱時會拋出錯誤?在兩個不同的範圍內定義具有相同名稱的兩個變量似乎是完全合理的,例如, conv1/sum1和conv2/sum1。

在我的代碼我是一個新的範圍內工作,但該圖將無法運行沒有數據被送入從最初的默認範圍的佔位符。是默認的範圍總是「在範圍內」出於某種原因?

如果圖邊可以跨越不同的作用域,並且不同作用域中的名稱不能共享,除非它們引用的是完全相同的節點,那麼這樣看起來會失去首先具有不同作用域的目的。我在這裏誤解了什麼?

謝謝!

回答

2

將訓練好的LSTM權重從一個圖表移動到另一個圖表的正確方法是什麼?

可以先(有保護對象保存參數)創建解碼圖,並創建一個GraphDef對象,你可以在更大的訓練圖形導入:

basegraph = tf.Graph() 
with basegraph.as_default(): 
    ***your graph*** 

traingraph = tf.Graph() 
with traingraph.as_default(): 
    tf.import_graph_def(basegraph.as_graph_def()) 
    ***your training graph*** 

確保您加載變量當你開始一個新圖形的會話。

我沒有使用這一功能的體驗,所以你可能要考慮它多一點

它是正確的說,開始一個新的會話「釋放資源」,但不刪除內存中的圖形?

是的,圖形對象仍持有它

在我看來,像「重複使用」功能允許Tensorflow當前的可變範圍之外具有相同名稱的搜索變量(在不同範圍內現有的),並在當前範圍內使用它們。它是否正確?如果是這樣,所有來自非當前範圍的鏈接到該變量的圖邊都會發生什麼?如果不是,那麼爲什麼Tensorflow會在兩個不同範圍內嘗試使用相同的變量名稱時會拋出錯誤?在兩個不同的範圍內定義具有相同名稱的兩個變量似乎是完全合理的,例如, conv1/sum1和conv2/sum1。

不,重用是爲了確定在現有名稱上使用get_variable時的行爲,當它爲true時,它將返回現有變量,否則它將返回一個新的名稱。通常tensorflow不應該拋出錯誤。你確定你使用tf.get_variable而不僅僅是tf.Variable嗎?

在我的代碼中,我在一個新的作用域內工作,但是圖形不會在沒有數據的情況下從最初的默認作用域引入佔位符。是默認的範圍總是「在範圍內」出於某種原因?

我真的不明白你的意思。並不總是必須使用。如果不需要佔位符來運行操作,則不必定義它。

如果圖邊可以跨越不同的作用域,並且不同作用域中的名稱不能被共享,除非它們引用的是完全相同的節點,那麼這看起來似乎打敗了首先具有不同作用域的目的。我在這裏誤解了什麼?

我覺得你的理解或範圍的使用是有缺陷的,見上面

+0

回覆:倒數第二個問題:我想訓練模型,然後在同一個Python腳本重用。我根本沒有使用保護程序。當我嘗試重新使用模型時,Tensorflow抱怨,因爲我沒有提供特定於培訓的佔位符。它們不需要重複使用,但它們仍然存在於圖表中。 – amm

+0

回覆:第一個問題:如果我必須更改輸入佔位符,該解決方案將如何工作?我正在使用minibatches進行培訓,但是在沒有minibatches的情況下重新使用模型,這似乎沒有改變輸入尺寸的機會。 – amm

+0

回覆:第三個問題:事實證明,MultiRNNCell是一個包裝器,它創建與包裝器對象本身不同的節點。有沒有文件介紹如何保存和恢復這種對象/節點集合?恢復後在節點上發生邊緣事件會發生什麼?當在新範圍內調用tf.get_variable()時,發生在節點上的邊事件會發生什麼?在此操作過程中,reuse = true是否會更改邊和節點之間的任何關係?我找不到解決這些問題的具體文件。 – amm

相關問題