2017-06-05 195 views
1

我完全失去了對的tensorflow保護方法。Tensorflow saver.restore()不恢復網絡

我試圖遵循的基本tensorflow深層神經網絡模型的教程。我想弄清楚如何訓練網絡幾次迭代,然後在另一個會話中加載模型。

with tf.Session() as sess: 
    graph = tf.Graph() 
    x = tf.placeholder(tf.float32,shape=[None,784]) 
    y_ = tf.placeholder(tf.float32, shape=[None,10]) 

    sess.run(global_variables_initializer()) 

    #Define the Network 
    #(This part is all copied from the tutorial - not copied for brevity) 
    #See here: https://www.tensorflow.org/versions/r0.12/tutorials/mnist/pros/ 

跳過培訓。

#Train the Network 
    train_step = tf.train.AdamOptimizer(1e-4).minimize(
        cross_entropy,global_step=global_step) 
    correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 

    saver = tf.train.Saver() 

    for i in range(101): 
     batch = mnist.train.next_batch(50) 
     if i%100 == 0: 
     train_accuracy = accuracy.eval(feed_dict= 
          {x:batch[0],y_:batch[1]}) 
     print 'Step %d, training accuracy %g'%(i,train_accuracy) 
      train_step.run(feed_dict={x:batch[0], y_: batch[1]}) 
     if i%100 == 0: 
      print 'Test accuracy %g'%accuracy.eval(feed_dict={x: 
         mnist.test.images, y_: mnist.test.labels}) 

     saver.save(sess,'./mnist_model') 

控制檯打印出:

步驟0,訓練精度0.16

測試精度0.0719

步驟100,訓練精度0.88

測試精度0.8734

接下來,我要加載模型

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('mnist_model.meta') 
    saver.restore(sess,tf.train.latest_checkpoint('./')) 
    sess.run(tf.global_variables_initializer()) 

現在我想重新測試,看看模型加載

print 'Test accuracy %g'%accuracy.eval(feed_dict={x: 
         mnist.test.images, y_: mnist.test.labels}) 

控制檯打印出:

測試精度0.1151

它似乎沒有顯示模型正在保存任何數據?我究竟做錯了什麼?

+0

你不應該運行'sess.run(tf.global_variables_initializer())'恢復權重後。這將重置您的所有權重 – martianwars

回答

0

當您保存您的模型,一般而局部變量是不是所有的全局變量保存在外部文件。您可以查看此answer以瞭解其差異。

您的恢復代碼中的錯誤正在調用tf.global_variable_initializer()saver.restore()。該saver.restore文檔提到,

變量恢復沒有被初始化,如恢復本身就是一種方式來初始化變量。

因此,嘗試刪除線,

sess.run(tf.global_variables_initializer()) 

理論上,應該將其替換爲,

sess.run(tf.local_variables_initializer()) 
+1

謝謝,這似乎已經解決了我的問題!如果文檔聲明'saver.restore()'是一個初始化過程,那麼'sess.run(tf。local_variables_initializer())用於任何目的? 這似乎也表明,教程,如[一個快速完整的教程來保存和恢復Tensorflow模型](http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-教程/)顯示不正確的用法,不是嗎? –

+0

你應該檢查['tf.local_variables()'](https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/local_variables)。如果這個列表非空,則需要它 – martianwars