2016-08-26 116 views
5

我想要做的是同時運行多個預先訓練的Tensorflow網絡。由於每個網絡中的一些變量的名稱可以相同,因此常見的解決方案是在創建網絡時使用名稱範圍。但問題是我已經訓練了這些模型並將訓練好的變量保存在多個檢查點文件中。在我創建網絡時使用名稱範圍後,我無法從檢查點文件加載變量。同時運行多個預先訓練的Tensorflow網絡

例如,我已經培訓了一個AlexNet,我想比較兩組變量,一組來自曆元10(保存在文件epoch_10.ckpt中),另一組來自曆元50(保存在文件epoch_50.ckpt)。因爲這兩個網絡完全相同,所以裏面的變量名稱是相同的。我可以用

with tf.name_scope("net1"): 
    net1 = CreateAlexNet() 
with tf.name_scope("net2"): 
    net2 = CreateAlexNet() 

然而,因爲當我訓練的這個網,我沒有用一個名字的範圍,我不能加載從.ckpt文件訓練有素的變量創建兩個網。儘管我可以在訓練網絡時將名稱範圍設置爲「net1」,但這會阻止我加載net2的變量。

我曾嘗試:

with tf.name_scope("net1"): 
    mySaver.restore(sess, 'epoch_10.ckpt') 
with tf.name_scope("net2"): 
    mySaver.restore(sess, 'epoch_50.ckpt') 

這是行不通的。

解決此問題的最佳方法是什麼?

回答

10

最簡單的解決方案是創建一個使用單獨的圖形各型號不同的會話:

# Build a graph containing `net1`. 
with tf.Graph().as_default() as net1_graph: 
    net1 = CreateAlexNet() 
    saver1 = tf.train.Saver(...) 
sess1 = tf.Session(graph=net1_graph) 
saver1.restore(sess1, 'epoch_10.ckpt') 

# Build a separate graph containing `net2`. 
with tf.Graph().as_default() as net2_graph: 
    net2 = CreateAlexNet() 
    saver2 = tf.train.Saver(...) 
sess2 = tf.Session(graph=net1_graph) 
saver2.restore(sess2, 'epoch_50.ckpt') 

如果這不出於某種原因,你必須使用一個tf.Session(如因爲你要的結果從兩個網絡中的另一TensorFlow計算相結合),最好的解決辦法是:

  1. 創建名稱範圍的不同的網絡,你已經這樣做,和
  2. 爲兩個網絡創建單獨的tf.train.Saver實例,並使用附加參數重新映射變量名稱。

constructing的儲戶,就可以通過一本字典爲var_list說法,在檢查點(即沒有名稱範圍前綴)以您在每個模型創建的tf.Variable對象映射變量的名稱。

你可以建設var_list編程,你應該能夠做到像下面這樣:

with tf.name_scope("net1"): 
    net1 = CreateAlexNet() 
with tf.name_scope("net2"): 
    net2 = CreateAlexNet() 

# Strip off the "net1/" prefix to get the names of the variables in the checkpoint. 
net1_varlist = {v.name.lstrip("net1/"): v 
       for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")} 
net1_saver = tf.train.Saver(var_list=net1_varlist) 

# Strip off the "net2/" prefix to get the names of the variables in the checkpoint. 
net2_varlist = {v.name.lstrip("net2/"): v 
       for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")} 
net2_saver = tf.train.Saver(var_list=net2_varlist) 

# ... 
net1_saver.restore(sess, "epoch_10.ckpt") 
net2_saver.restore(sess, "epoch_50.ckpt") 
+0

太棒了! – denru

+0

使用lstrip剝離前綴可能會導致錯誤的結果。請使用切片代替。代碼的其他部分完美地工作。另一個問題是,我發現一個變量的名稱有一個像「:0」,「:1」的後綴。在將變量存儲到檢查點文件之前,我需要擺脫這個後綴嗎? – denru

+0

任何人都試過這個答案?我遇到的問題與'恢復'功能沒有做任何事情:http://stackoverflow.com/questions/41607144/loading-two-models-from-saver-in-the-same-tensorflow-session – TheCriticalImperitive