2017-10-06 50 views
0

如何將單個神經網絡的權重保存在張量流圖中,以便它可以在不同的程序中加載到具有相同體系結構的網絡中?在張量流圖中保存單個神經網絡的權重

我的訓練碼僅需要3個其他神經網絡用於訓練過程。如果我要使用saver.save(sess, 'my-model)',它不會保存張量流圖中的所有變量嗎?這對我的用例來說似乎不正確。

也許這是來自我對tensorflow應該如何工作的誤解。我正確地處理這個問題嗎?

+0

你可以選擇你想保存(HTTPS哪個變量:// WWW .tensorflow.org/programmers_guide/saved_model)。只需寫:saver = tf.train.Saver({「my_var」:my_var}) –

回答

1

最好的方法是使用tensorflow變量作用域。假設你有model_1,model_2和model_3你只想保存model_1:

首先,在你的訓練代碼中定義的車型:

with tf.variable_scope('model_1'): 
    model one declaration here 
    ... 
with tf.variable_scope('model_2'): 
    model one declaration here 
    ... 
with tf.variable_scope('model_3'): 
    model one declaration here 
    ... 

接下來,定義了model_1的變量保護:

model_1_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="model_1") 
saver = tf.train.Saver(model_1_variables) 

雖然訓練可以節省檢查點,就像你提到的:

saver.save(sess, 'my-model') 

您的訓練後做,你要恢復的權重的評估代碼,請確保您定義model_1和保護以同樣的方式:

with tf.variable_scope('model_1'): 
    model one declaration here 
    ... 
model_1_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="model_1") 
saver = tf.train.Saver(model_1_variables) 
sess = tf.Session() 
saver.restore(sess, 'my-model')` 
+0

完美工作。謝謝。 –