2016-04-22 337 views
1

我是TensorFlow的初學者,目前正在培訓CNN。Tensorflow:保存和恢復模型參數

我爲了節省模型使用的參數,使用節電器,但我有擔憂這是否會本身存儲在模型中使用的所有變量,以及足以將值恢復爲重新運行在訓練的網絡上執行分類/測試的程序。

讓我們看一下TensorFlow給出的着名示例MNIST。

在這個例子中,我們有一堆卷積塊,所有塊都有權重,並且偏置變量在程序運行時被初始化。

W_conv1 = init_weight([5,5,1,32]) 
b_conv1 = init_bias([32]) 

處理好幾個圖層後,我們創建一個會話,並初始化添加到圖形中的所有變量。

sess = tf.Session() 
sess.run(tf.initialize_all_variables()) 
saver = tf.train.Saver() 

這裏,是否有可能作出評論saver.save代碼,並通過培訓後saver.restore(SESS,FILE_PATH)取代它,以恢復體重,偏見等,後面的參數到圖表?這是應該如何?

for i in range(1000): 
... 

    if i%500 == 0: 
    saver.save(sess,"model%d.cpkt"%(i)) 

我目前的訓練大數據集,因此終止,並重新啓動培訓是浪費時間和資源,所以我要求有人請澄清之前,我開始訓練。

+0

這是有點不清楚你在問什麼。 「評論saver.save代碼,並用saver.restore(sess,file_path)替換它」你不想存儲你的訓練值並且重新設置以前的訓練(通過恢復)? 「如此終止,重新開始培訓是一種浪費」。這意味着當您完成所有培訓後,您想要保存一次模型? –

+0

@Sung Kim:你的後一個問題的答案是肯定的。我不打算用存儲的值重新開始訓練,而是在完成訓練後簡單地將模型保存一次。因爲在Matlab中,這非常簡單,事實上,這是我第一次編程Python和TensorFlow,所以我不知道是否有其他優雅的方式來保存參數。 –

回答

4

如果你想保存最後的結果只有一次,你可以這樣做:

with tf.Session() as sess: 
    for i in range(1000): 
    ... 


    path = saver.save(sess, "model.ckpt") # out of the loop 
    print "Saved:", path 

在其他程序中,可以使用從saver.save返回預測什麼的路徑加載模型。你可以在https://github.com/sugyan/tensorflow-mnist看到一些例子。

1

根據here和Sung Kim解決方案的說明,我寫了一個非常簡單的模型來準確解決這個問題。基本上這樣你需要從同一個類創建一個對象,並從保存器中恢復它的變量。你可以找到這個解決方案的例子here