0

我已經訓練了一個帶有tensorflow的模型並在訓練過程中使用了批量規範化。批量規範化要求用戶傳遞一個布爾值,稱爲is_training,以設置模型是處於訓練階段還是測試階段。恢復訓練後的張量流模型,編輯與節點相關聯的值並保存它

當模型進行訓練,如下圖所示

is_training = tf.constant(True, dtype=tf.bool, name='is_training') 

我已經保存了訓練模型is_training被設置爲常數,這些文件包括檢查點,.META文件,文件的.index和。數據。我想恢復模型並使用它進行推理。 該模型不能被重新訓練。因此,我想恢復現有模型,將is_training的值設置爲False,然後將模型保存回來。 如何編輯與該節點關聯的布爾值,並再次保存模型?

+0

,如果使用它本來就容易'is_training = tf.Variable..'而不是恆定的 –

+0

有爲什麼'is_training'需要一個tensorflow常數的原因是什麼?它不能是一個Python布爾?請注意,將'is_training'更改爲python bool不應在恢復模型時發生錯誤。 – GeertH

+0

@GeertH它可能是,問題是我加載模型後如何將'is_training'設置爲'False',然後將其保存回來。因此,當它再次恢復時,節點的值爲「False」。 – dpk

回答

1

您可以使用tf.train.import_meta_graph的參數input_map將圖張量重新映射到更新的值。

config = tf.ConfigProto(allow_soft_placement=True) 
with tf.Session(config=config) as sess: 
    # define the new is_training tensor 
    is_training = tf.constant(False, dtype=tf.bool, name='is_training') 

    # now import the graph using the .meta file of the checkpoint 
    saver = tf.train.import_meta_graph(
    '/path/to/model.meta', input_map={'is_training:0':is_training}) 

    # restore all weights using the model checkpoint 
    saver.restore(sess, '/path/to/model') 

    # save updated graph and variables values 
    saver.save(sess, '/path/to/new-model-name') 
+0

如果使用了input_map,上面的代碼會拋出一個錯誤'ValueError:tf.import_graph_def()需要一個非空名稱' – dpk

+0

我已經使用'tensorflow == 1.2.0'測試了這段代碼,希望它有幫助;也不是'tf.import_graph_def'。看我的代碼。 –

+0

我試過你的代碼,錯誤是這條線拋出的,'saver = tf.train.import_meta_graph(r'D:\ code \ iprings \ k-fold-model \ VanillaCNN_24.0000.meta',input_map = {'is_training':is_training})' – dpk

相關問題