2017-03-29 72 views
6

在Tensorflow中,我們可以使用Between-graph Replication構建和創建多個Tensorflow會話進行分佈式培訓。 MonitoredTrainingSession()座標多個Tensorflow會話,並且有checkpoint_dirMonitoredTrainingSession()的參數來恢復Tensorflow會話/圖形。現在我有以下問題:`MonitoredTrainingSession()`如何與「恢復」和「測試模式」一起工作?

  1. 我們通常使用的tf.train.Saver()對象通過saver.restore(...)恢復Tensorflow圖。但我們如何使用MonitoredTrainingSession()來恢復它們?
  2. 由於我們運行多個進程,並且每個進程都構建並創建一個Tensorflow會話進行培訓,所以我不知道在訓練之後是否還需要運行多個進程以進行測試(或預測)。換句話說,MonitoredTrainingSession()如何與測試(或預測)模式一起工作?

我讀了Tensorflow Doc,但沒有找到這兩個問題的答案。我非常感謝任何人有解決方案。謝謝!

回答

-1
  1. 看起來恢復是爲您處理的。詮釋他的API文檔,它說,調用MonitoredTrainingSession返回MonitoredSession的一個實例,它在創建「......如果檢查點存在恢復變量...」

  2. 退房tf.contrib.learn.Estimator(..).predict(..)更具體tf.contrib.learn.Estimator(..)._infer_model(..)方法herehere。他們還在那裏創建一個MonitoredSession。

0

簡短的回答:

  1. 您需要通過全球一步傳遞給mon_sess.run優化。這使得保存和檢索保存的檢查點成爲可能。
  2. 可以通過單個MonitoredTrainingSession同時運行培訓+交叉驗證會話。首先,您需要通過培訓批次並通過同一圖表的不同流來交叉驗證批次(我建議您查看this guide瞭解如何執行此操作的信息)。其次,你必須 - 對mon_sess.run() - 傳遞訓練流的優化器,以及交叉驗證流的丟失參數(你想跟蹤的參數)。如果要與訓練分開運行測試會話,只需通過圖運行測試集,然後在圖中只運行test_loss(/要跟蹤的其他參數)。有關如何完成的更多細節,請看下面。

龍答:

我會更新我的答案,因爲我自己弄的有什麼可以與tf.train.MonitoredSession做一個更好的視野(tf.train.MonitoredTrainingSession只是建立的專用版本tf.train.MonitoredSession,在source code中可以看到)。

以下是顯示如何將檢查點每隔5秒保存到'./ckpt_dir'的示例代碼。當被中斷時,將重新啓動其最後一次保存的檢查點:

def train(inputs, labels_onehot, global_step): 
    out = tf.contrib.layers.fully_connected(
          inputs, 
          num_outputs=10, 
          activation_fn=tf.nn.sigmoid) 
    loss = tf.reduce_mean(
      tf.reduce_sum(
       tf.nn.sigmoid_cross_entropy_with_logits(
          logits=out, 
          labels=labels_onehot), axis=1)) 
    train_op = opt.minimize(loss, global_step=global_step) 
    return train_op 

with tf.Graph().as_default(): 
    global_step = tf.train.get_or_create_global_step() 
    inputs = ... 
    labels_onehot = ... 
    train_op = train(inputs, labels_onehot, global_step) 

    with tf.train.MonitoredTrainingSession(
     checkpoint_dir='./ckpt_dir', 
     save_checkpoint_secs=5, 
     hooks=[ ... ] # Choose your hooks 
    ) as mon_sess: 
     while not mon_sess.should_stop(): 
      mon_sess.run(train_op) 

什麼在MonitoredTrainingSession是發生在爲了實現這一目標實際上是三兩件事:

  1. 的tf.train.MonitoredTrainingSession創建tf.train.Scaffold對象,它像蜘蛛網一樣工作;它收集您需要訓練的部分,保存並加載模型。
  2. 它創建一個tf.train.ChiefSessionCreator對象。我對這個知識的理解是有限的,但從我的理解來看,它被用於當你的tf算法分佈在多個服務器上時。我認爲它告訴運行該文件的計算機它是主計算機,並且在這裏檢查點目錄應該被保存,並且記錄器應該在這裏記錄它們的數據等。
  3. 它創建tf.train.CheckpointSaverHook,用於保存檢查點。

爲了使其工作,必須將tf.train.CheckpointSaverHook和tf.train.ChiefSessionCreator傳遞給檢查點目錄和腳手架的相同引用。如果上面其在例如參數tf.train.MonitoredTrainingSession將被用3層以上的組件來實現,這將是這個樣子:

checkpoint_dir = './ckpt_dir' 

scaffold = tf.train.Scaffold() 
saverhook = tf.train.CheckpointSaverHook(
    checkpoint_dir=checkpoint_dir, 
    save_secs=5 
    scaffold=scaffold 
) 
session_creator = tf.train.ChiefSessionCreator(
    scaffold=scaffold, 
    checkpoint_dir=checkpoint_dir 
) 

with tf.train.MonitoredSession(
    session_creator=session_creator, 
    hooks=[saverhook]) as mon_sess: 
     while not mon_sess.should_stop(): 
      mon_sess.run(train_op) 

爲了做火車+交叉驗證會話,你需要簡單地通過相同的曲線圖可以在兩個不同的組,然後運行(在while循環以上):

mon_sess.run([train_op, cross_validation_loss]) 

這將運行對於訓練集的訓練優化器,以及用於在驗證所述validation_loss參數組。如果您的圖形正確實施,這意味着該圖形只會在訓練集上進行訓練,並且只能在交叉驗證集上進行驗證。

相關問題