2017-04-06 36 views
11

我正試圖在TensorFlow中重新開始一個模型訓練,通過拾取它離開的位置。我想使用最近添加的(0.12+我認爲)import_meta_graph()以便不重構圖。Python TensorFlow:如何使用優化器和import_meta_graph重新開始訓練?

我見過這方面的解決方案,例如Tensorflow: How to save/restore a model?,但我碰到AdamOptimizer的問題,具體而言,我得到一個ValueError: cannot add op with name <my weights variable name>/Adam as that name is already used錯誤。 This can be fixed by initializing,但隨後我的模型值被清除!

還有其他的答案和一些完整的例子,但他們總是看起來更老,所以不包括新的import_meta_graph()方法,或沒有非張量優化器。我能找到的最接近的問題是tensorflow: saving and restoring session,但沒有最終的明確解決方案,而且這個例子非常複雜。

理想情況下,我想要一個簡單的可運行示例,從頭開始,停止,然後再次提取。我有一些工作(下),但也想知道我是否錯過了一些東西。當然,我不是唯一一個這樣做?

+0

我和AdamOptimizer有同樣的問題。我設法通過將我的操作放到集合中來讓事情發揮作用。這個例子幫助了我很多:http://www.seaandsailor.com/tensorflow-checkpointing.html –

回答

4

下面是我從閱讀文檔,其他類似的解決方案以及試驗和錯誤中得出的結論。這是一個關於隨機數據的簡單自動編碼器。如果跑了,然後又跑了,它會從停下的地方繼續跑下去(即第一次跑的成本函數從大約0.5 - > 0.3秒開始跑〜0.3)。除非我錯過了一些東西,否則所有的保存,構造函數,模型構建,add_to_collection都需要並按照精確的順序進行,但可能有一個更簡單的方法。

是的,加載與import_meta_graph圖形並不是真的需要在這裏,因爲代碼是正確的上面,但是我想在我的實際應用程序。

from __future__ import print_function 
import tensorflow as tf 
import os 
import math 
import numpy as np 

output_dir = "/root/Data/temp" 
model_checkpoint_file_base = os.path.join(output_dir, "model.ckpt") 

input_length = 10 
encoded_length = 3 
learning_rate = 0.001 
n_epochs = 10 
n_batches = 10 
if not os.path.exists(model_checkpoint_file_base + ".meta"): 
    print("Making new") 
    brand_new = True 

    x_in = tf.placeholder(tf.float32, [None, input_length], name="x_in") 
    W_enc = tf.Variable(tf.random_uniform([input_length, encoded_length], 
              -1.0/math.sqrt(input_length), 
              1.0/math.sqrt(input_length)), name="W_enc") 
    b_enc = tf.Variable(tf.zeros(encoded_length), name="b_enc") 
    encoded = tf.nn.tanh(tf.matmul(x_in, W_enc) + b_enc, name="encoded") 
    W_dec = tf.transpose(W_enc, name="W_dec") 
    b_dec = tf.Variable(tf.zeros(input_length), name="b_dec") 
    decoded = tf.nn.tanh(tf.matmul(encoded, W_dec) + b_dec, name="decoded") 
    cost = tf.sqrt(tf.reduce_mean(tf.square(decoded - x_in)), name="cost") 

    saver = tf.train.Saver() 
else: 
    print("Reloading existing") 
    brand_new = False 
    saver = tf.train.import_meta_graph(model_checkpoint_file_base + ".meta") 
    g = tf.get_default_graph() 
    x_in = g.get_tensor_by_name("x_in:0") 
    cost = g.get_tensor_by_name("cost:0") 


sess = tf.Session() 
if brand_new: 
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost) 
    init = tf.global_variables_initializer() 
    sess.run(init) 
    tf.add_to_collection("optimizer", optimizer) 
else: 
    saver.restore(sess, model_checkpoint_file_base) 
    optimizer = tf.get_collection("optimizer")[0] 

for epoch_i in range(n_epochs): 
    for batch in range(n_batches): 
     batch = np.random.rand(50, input_length) 
     _, curr_cost = sess.run([optimizer, cost], feed_dict={x_in: batch}) 
     print("batch_cost:", curr_cost) 
     save_path = tf.train.Saver().save(sess, model_checkpoint_file_base) 
2

我有同樣的問題,我只是想出了什麼是錯的,至少在我的代碼。

最後,我在saver.restore()中使用了錯誤的文件名。此功能必須在指定的文件名,不帶文件擴展名,就像saver.save()功能:的

saver.restore(sess, 'model-1') 

代替

saver.restore(sess, 'model-1.data-00000-of-00001') 

有了這個,我做你想要做什麼:從頭開始,停下來,然後再接。我不需要使用tf.train.import_meta_graph()函數從元文件初始化第二個保存程序,並且在初始化優化程序後,我不需要明確說明tf.initialize_all_variables()

我的完整模型恢復這個樣子的:

with tf.Session() as sess: 
    saver = tf.train.Saver() 
    sess.run(tf.global_variables_initializer()) 
    saver.restore(sess, model-1) 

我認爲在協議V1你仍然不得不把.ckpt添加到文件名,併爲import_meta_graph()你仍然需要添加.meta,這可能會導致用戶之間有些混亂。也許這應該在文檔中更明確地指出。

0

當您在還原會話中創建保護程序對象時可能會出現問題。

在恢復會話中使用下面的代碼時,我獲得了和你一樣的錯誤。

saver = tf.train.import_meta_graph('tmp/hsmodel.meta') 
saver.restore(sess, tf.train.latest_checkpoint('tmp/')) 

但是,當我以這種方式改變了,

saver = tf.train.Saver() 
saver.restore(sess, "tmp/hsmodel") 

該錯誤已消失。 「tmp/hsmodel」是我在保存會話中給saver.save(sess,「tmp/hsmodel」)的路徑。

存儲和恢復培訓MNIST網絡會議(包含亞當優化器)的簡單例子在這裏。這有助於我比較我的代碼並解決問題。

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py

相關問題