我用tensorflow運行了一個培訓工作,並在驗證集上得到了以下損失曲線。網絡在第6000次迭代後開始過度配合。所以我想在過度使用之前得到這個模型。Tensorflow:以最小驗證錯誤保存模型
我的訓練碼是一樣的東西如下:
train_step = ......
summary = tf.scalar_summary(l1_loss.op.name, l1_loss)
summary_writer = tf.train.SummaryWriter("checkpoint", sess.graph)
saver = tf.train.Saver()
for i in xrange(20000):
batch = get_next_batch(batch_size)
sess.run(train_step, feed_dict = {x: batch.x, y:batch.y})
if (i+1) % 100 == 0:
saver.save(sess, "checkpoint/net", global_step = i+1)
summary_str = sess.run(summary, feed_dict=validation_feed_dict)
summary_writer.add_summary(summary_str, i+1)
summary_writer.flush()
訓練結束之後,只有五保存檢查站(19600,19700,19800,19900,20000)。根據驗證錯誤,有什麼辦法讓tensorflow保存檢查點?
P.S.我知道tf.train.Saver
有一個max_to_keep
的論點,原則上可以保存所有的檢查點。但那不是我想要的(除非它是唯一的選擇)。我希望保存者保持檢查點迄今最小的驗證損失。那可能嗎?