2016-08-31 67 views
4

我用tensorflow運行了一個培訓工作,並在驗證集上得到了以下損失曲線。網絡在第6000次迭代後開始過度配合。所以我想在過度使用之前得到這個模型。Tensorflow:以最小驗證錯誤保存模型

loss

我的訓練碼是一樣的東西如下:

train_step = ...... 
summary = tf.scalar_summary(l1_loss.op.name, l1_loss) 
summary_writer = tf.train.SummaryWriter("checkpoint", sess.graph) 
saver = tf.train.Saver() 
for i in xrange(20000): 
    batch = get_next_batch(batch_size) 
    sess.run(train_step, feed_dict = {x: batch.x, y:batch.y}) 
    if (i+1) % 100 == 0: 
     saver.save(sess, "checkpoint/net", global_step = i+1) 
     summary_str = sess.run(summary, feed_dict=validation_feed_dict) 
     summary_writer.add_summary(summary_str, i+1) 
     summary_writer.flush() 

訓練結束之後,只有五保存檢查站(19600,19700,19800,19900,20000)。根據驗證錯誤,有什麼辦法讓tensorflow保存檢查點?

P.S.我知道tf.train.Saver有一個max_to_keep的論點,原則上可以保存所有的檢查點。但那不是我想要的(除非它是唯一的選擇)。我希望保存者保持檢查點迄今最小的驗證損失。那可能嗎?

回答

5

您需要計算驗證集上的分類準確性,並跟蹤迄今爲止所見最好的分類準確性,並且只有在驗證準確性發現改進後才寫入檢查點。

如果數據集和/或模型很大,那麼您可能必須將驗證集分成批以適應內存中的計算。

本教程介紹瞭如何正確地做你想要的:

https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/04_Save_Restore.ipynb

它也可以作爲一個簡短的視頻:

https://www.youtube.com/watch?v=Lx8JUJROkh0

0

在你的session.run中,你需要明確地詢問損失。然後創建一個包含您最近一次評估損失的列表,並且只有當前評估損失小於最後兩次存儲損失才能創建檢查點。