如何將單個神經網絡的權重保存在張量流圖中,以便它可以在不同的程序中加載到具有相同體系結構的網絡中?在張量流圖中保存單個神經網絡的權重
我的訓練碼僅需要3個其他神經網絡用於訓練過程。如果我要使用saver.save(sess, 'my-model)'
,它不會保存張量流圖中的所有變量嗎?這對我的用例來說似乎不正確。
也許這是來自我對tensorflow應該如何工作的誤解。我正確地處理這個問題嗎?
如何將單個神經網絡的權重保存在張量流圖中,以便它可以在不同的程序中加載到具有相同體系結構的網絡中?在張量流圖中保存單個神經網絡的權重
我的訓練碼僅需要3個其他神經網絡用於訓練過程。如果我要使用saver.save(sess, 'my-model)'
,它不會保存張量流圖中的所有變量嗎?這對我的用例來說似乎不正確。
也許這是來自我對tensorflow應該如何工作的誤解。我正確地處理這個問題嗎?
最好的方法是使用tensorflow變量作用域。假設你有model_1,model_2和model_3你只想保存model_1:
首先,在你的訓練代碼中定義的車型:
with tf.variable_scope('model_1'):
model one declaration here
...
with tf.variable_scope('model_2'):
model one declaration here
...
with tf.variable_scope('model_3'):
model one declaration here
...
接下來,定義了model_1的變量保護:
model_1_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="model_1")
saver = tf.train.Saver(model_1_variables)
雖然訓練可以節省檢查點,就像你提到的:
saver.save(sess, 'my-model')
您的訓練後做,你要恢復的權重的評估代碼,請確保您定義model_1和保護以同樣的方式:
with tf.variable_scope('model_1'):
model one declaration here
...
model_1_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="model_1")
saver = tf.train.Saver(model_1_variables)
sess = tf.Session()
saver.restore(sess, 'my-model')`
完美工作。謝謝。 –
你可以選擇你想保存(HTTPS哪個變量:// WWW .tensorflow.org/programmers_guide/saved_model)。只需寫:saver = tf.train.Saver({「my_var」:my_var}) –