2016-12-27 102 views
1

我已經預訓練了網絡,並且我正在試圖僅僅獲取它的一部分(子圖)tf圖以及變量和保存對象。在tensorflow中提取子圖

這就是我正在做它:

subgraph = tf.graph_util.extract_sub_graph(default_graph, list of nodes to preserve) 
tf.reset_default_graph() 
tf.import_graph_def(subgraph) 

然而,這將刪除所有變量(當我打電話reset_default_graph)。即使如果我明確地將變量的操作節點(僅「變量」類型操作)添加到「要保留的節點列表」中。

如何在保留變量值的同時保留較大圖的子圖? 這是另外一些新的節點「保留列表」的問題嗎?

圖形節點和變量之間的關係仍然不清楚,教程僅僅提到創建變量會在圖形中創建一些操作(節點)。

回答

0

我覺得你在做什麼看起來不錯。正如你所說的,一個變量只是一個簡單的操作(圖中的一個節點),用於輸出特定值的張量。您應該能夠將變量節點添加到列表中以保留它們,就像您已經在做的那樣。你可以使用print(sess.graph_def)來確保你提供的名字是正確的嗎?

+0

變量是一組連接的操作。通常,它以ops:「variable [variable]」,「assign [assign]」,「read [identity]」(第一部分是名稱,方括號用於類型)以及整套初始化操作符爲主。問題在於圖導出會以一種不被認爲是變量的方式來削減變量結構。選擇所有必需的操作非常麻煩 - 而且不是最聰明的方法。 – Pietrko

+0

是的,的確如此。如果您查看extract_sub_graph的函數接口(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/graph_util_impl.py#L110),注意到它只是簡單的函數,沒有任何智能處理對於變量,「選擇所有需要的操作」可能仍然是你最好的選擇。好消息是我認爲你可以編寫一個簡單的函數(使用graph_def作爲輸入)來自動執行這個繁瑣的選擇變量相關節點的過程。 –

+0

好吧,我希望我可以避免這種情況,也許存在一些乾淨而快速的方式來使用現有的API。謝謝。 – Pietrko