2017-10-11 32 views
2

我一直在訓練TensorFlow模型約一週,偶爾會有微調。TensorFlow:NotFoundError:在檢查點找不到關鍵

今天,當我試圖微調模型我得到了錯誤:

tensorflow.python.framework.errors_impl.NotFoundError: Key conv_classifier/loss/total_loss/avg not found in checkpoint 
[[Node: save/RestoreV2_37 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_37/tensor_names, save/RestoreV2_37/shape_and_slices)]] 

使用inspect_checkpoint.py我看到檢查點文件現已在它兩個空層:

... 
conv_decode4/ort_weights/Momentum (DT_FLOAT) [7,7,64,64] 
loss/cross_entropy/avg (DT_FLOAT) [] 
loss/total_loss/avg (DT_FLOAT) [] 
up1/up_filter (DT_FLOAT) [2,2,64,64] 
... 

如何我能解決這個問題嗎?

SOLUTION:下面編輯爲清楚起見

繼mrry建議:

code_to_checkpoint_variable_map = {var.op.name: var for var in tf.global_variables()} 
for code_variable_name, checkpoint_variable_name in { 
    "inference/conv_classifier/weight_loss/avg" : "loss/weight_loss/avg", 
    "inference/conv_classifier/loss/total_loss/avg" : "loss/total_loss/avg", 
    "inference/conv_classifier/loss/cross_entropy/avg": "loss/cross_entropy/avg", 
}.items(): 
    code_to_checkpoint_variable_map[checkpoint_variable_name] = code_to_checkpoint_variable_map[code_variable_name] 
    del code_to_checkpoint_variable_map[code_variable_name] 

saver = tf.train.Saver(code_to_checkpoint_variable_map) 
saver.restore(sess, tf.train.latest_checkpoint('./logs')) 
+0

檢查點文件中的關鍵是「loss/weight_loss/avg」嗎?第一個異常消息表明它不是。 (我不明白你做的其他修改,但完整的代碼塊看起來很合理。) – mrry

回答

2

幸運的是,它並不像你的關卡是腐敗,而是一些變量在你的程序已被重新命名。我假設名爲"loss/total_loss/avg"的檢查點值應該恢復到名爲"conv_classifier/loss/total_loss/avg"的變量。您可以通過在創建tf.train.Saver時通過自定義var_list來解決此問題。

name_to_var_map = {var.op.name: var for var in tf.global_variables()} 

name_to_var_map["loss/total_loss/avg"] = name_to_var_map[ 
    "conv_classifier/loss/total_loss/avg"] 
del name_to_var_map["conv_classifier/loss/total_loss/avg"] 

# Depending on how the names have changed, you may also need to do: 
# name_to_var_map["loss/cross_entropy/avg"] = name_to_var_map[ 
#  "conv_classifier/loss/cross_entropy/avg"] 
# del name_to_var_map["conv_classifier/loss/cross_entropy/avg"] 

saver = tf.train.Saver(name_to_var_map) 

然後,您可以使用saver.restore()恢復您的模型。或者,您可以使用此方法恢復模型,並使用默認構造的tf.train.Saver將其保存爲規範格式。

+0

我有點困惑:'name_to_var_map [「loss/total_loss/avg」]'不需要是(實際上不應該)),因爲我們正在分配它。只是一個健全的檢查:是'name_to_var_map [「conv_classifier/loss/total_loss/avg」]'存在? – mrry

+0

這就是我期待的,是的!我試着用Python 2.7和3.4來改變字典理解,它似乎工作。也許增加'name_to_var_map = dict(name_to_var_map)'會使它變得可變嗎? – mrry

+0

啊,對不起:原來的詞典理解應該使用'var.op.name'作爲鍵,而不是'var.name'。 (「Variable.name」屬性中的'「:0」是一個[歷史工件](https://stackoverflow.com/a/36156697/3574081)。)這是否解決了問題? – mrry