2017-03-01 66 views
3

在我當前的項目中,我訓練模型並每隔100個迭代步驟保存檢查點。檢查點文件全部保存到相同的目錄(model.ckpt-100,model.ckpt-200,model.ckpt-300等)。之後,我想根據所有保存的檢查點的驗證數據來評估模型,而不僅僅是最新的檢查點。張量流:在多個檢查點上運行模型評估

目前我的恢復點文件的代碼看起來是這樣的:

ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 
ckpt_list = saver.last_checkpoints 
print(ckpt_list) 
if ckpt and ckpt.model_checkpoint_path: 
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path) 
    saver.restore(sess, ckpt.model_checkpoint_path) 
    # extract global_step from it. 
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 
    print('Succesfully loaded model from %s at step=%s.' % 
      (ckpt.model_checkpoint_path, global_step)) 
else: 
    print('No checkpoint file found') 
    return 

然而,這僅恢復最新保存的檢查點文件。那麼如何在所有保存的檢查點文件上編寫循環?我嘗試使用saver.last_checkpoints獲得檢查點文件的列表,但是,返回的列表是空的。

任何幫助將不勝感激,在此先感謝!

+0

如何準確保存模型?您是在自己建立輸出文件的名稱,還是在調用'saver.save(..)'時使用'global_step'參數? – kaufmanu

回答

1

您可以通過在目錄中的文件重複:

import os 

dir_path = './' #change that to wherever your files are 
ckpt_files = [f for f in os.listdir(dir_path) if os.path.isfile(
    os.path.join(dir_path, f)) and 'ckpt' in f] 

for ckpt_file in ckpt_files: 
    saver.restore(sess, dir_path + ckpt_file) 
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 
    print('Succesfully loaded model from %s at step=%s.' % 
      (ckpt.model_checkpoint_path, global_step)) 

    # Do your thing 

在上面的列表理解添加更多的條件,更有選擇性,如:and 'meta' not in f等取決於什麼是在DIR和你金丹版本有

0

謝謝你。但是我得到的錯誤

「NotFoundError(見上文回溯):在檢查點沒有找到關鍵CONV 2 /偏見/ ExponentialMovingAverage」

其中CONV 2 /偏見是一個變量的作用域。我使用保存版本v2。

同時我嘗試了不同的(比特更簡單的代碼),並得到了同樣的錯誤:

誤差實際上發生在這一段代碼(在variables_to_restore =)

fileBaseName = FLAGS.checkpoint_dir + '/model.ckpt-' 

    for global_step in range(0,100,10): # range over the global steps where checkpoints were saved 
    x_str = str(global_step) 
    fileName = fileBaseName+x_str 
    print(fileName) 
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 

    #restore checkpoint file 
    saver.restore(sess, fileName) 

# Restore the moving average version of the learned variables for eval. 
variable_averages = tf.train.ExponentialMovingAverage(
    MOVING_AVERAGE_DECAY) 
variables_to_restore = variable_averages.variables_to_restore() 
saver = tf.train.Saver(variables_to_restore) 

我不知道如何解決這個錯誤。它可能與保護程序版本有關嗎?或者必須是檢查點保存部分的錯誤?

非常感謝。 TheJude