簡短的回答:
- 您需要通過全球一步傳遞給mon_sess.run優化。這使得保存和檢索保存的檢查點成爲可能。
- 可以通過單個MonitoredTrainingSession同時運行培訓+交叉驗證會話。首先,您需要通過培訓批次並通過同一圖表的不同流來交叉驗證批次(我建議您查看this guide瞭解如何執行此操作的信息)。其次,你必須 - 對mon_sess.run() - 傳遞訓練流的優化器,以及交叉驗證流的丟失參數(你想跟蹤的參數)。如果要與訓練分開運行測試會話,只需通過圖運行測試集,然後在圖中只運行test_loss(/要跟蹤的其他參數)。有關如何完成的更多細節,請看下面。
龍答:
我會更新我的答案,因爲我自己弄的有什麼可以與tf.train.MonitoredSession做一個更好的視野(tf.train.MonitoredTrainingSession只是建立的專用版本tf.train.MonitoredSession,在source code中可以看到)。
以下是顯示如何將檢查點每隔5秒保存到'./ckpt_dir'的示例代碼。當被中斷時,將重新啓動其最後一次保存的檢查點:
def train(inputs, labels_onehot, global_step):
out = tf.contrib.layers.fully_connected(
inputs,
num_outputs=10,
activation_fn=tf.nn.sigmoid)
loss = tf.reduce_mean(
tf.reduce_sum(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=out,
labels=labels_onehot), axis=1))
train_op = opt.minimize(loss, global_step=global_step)
return train_op
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()
inputs = ...
labels_onehot = ...
train_op = train(inputs, labels_onehot, global_step)
with tf.train.MonitoredTrainingSession(
checkpoint_dir='./ckpt_dir',
save_checkpoint_secs=5,
hooks=[ ... ] # Choose your hooks
) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
什麼在MonitoredTrainingSession是發生在爲了實現這一目標實際上是三兩件事:
- 的tf.train.MonitoredTrainingSession創建tf.train.Scaffold對象,它像蜘蛛網一樣工作;它收集您需要訓練的部分,保存並加載模型。
- 它創建一個tf.train.ChiefSessionCreator對象。我對這個知識的理解是有限的,但從我的理解來看,它被用於當你的tf算法分佈在多個服務器上時。我認爲它告訴運行該文件的計算機它是主計算機,並且在這裏檢查點目錄應該被保存,並且記錄器應該在這裏記錄它們的數據等。
- 它創建tf.train.CheckpointSaverHook,用於保存檢查點。
爲了使其工作,必須將tf.train.CheckpointSaverHook和tf.train.ChiefSessionCreator傳遞給檢查點目錄和腳手架的相同引用。如果上面其在例如參數tf.train.MonitoredTrainingSession將被用3層以上的組件來實現,這將是這個樣子:
checkpoint_dir = './ckpt_dir'
scaffold = tf.train.Scaffold()
saverhook = tf.train.CheckpointSaverHook(
checkpoint_dir=checkpoint_dir,
save_secs=5
scaffold=scaffold
)
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold,
checkpoint_dir=checkpoint_dir
)
with tf.train.MonitoredSession(
session_creator=session_creator,
hooks=[saverhook]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
爲了做火車+交叉驗證會話,你需要簡單地通過相同的曲線圖可以在兩個不同的組,然後運行(在while循環以上):
mon_sess.run([train_op, cross_validation_loss])
這將運行對於訓練集的訓練優化器,以及用於在驗證所述validation_loss參數組。如果您的圖形正確實施,這意味着該圖形只會在訓練集上進行訓練,並且只能在交叉驗證集上進行驗證。