2016-04-04 142 views
2

您好,我想從tensorflow中調整VGG模型。我有兩個問題。從張量流模型獲取重量

如何從網絡中獲取權重? trainable_variables爲我返回空列表。

我從這裏使用現有的模型:https://github.com/ry/tensorflow-vgg16。 我找到關於獲得重量的帖子,但是這對我來說不適用,因爲import_graph_def。 Get the value of some weights in a model trained by TensorFlow

import tensorflow as tf 
import PIL.Image 
import numpy as np 

with open("../vgg16.tfmodel", mode='rb') as f: 
    fileContent = f.read() 

graph_def = tf.GraphDef() 
graph_def.ParseFromString(fileContent) 

images = tf.placeholder("float", [None, 224, 224, 3]) 

tf.import_graph_def(graph_def, input_map={ "images": images }) 
print("graph loaded from disk") 

graph = tf.get_default_graph() 

cat = np.asarray(PIL.Image.open('../cat224.jpg')) 
print(cat.shape) 
init = tf.initialize_all_variables() 

with tf.Session(graph=graph) as sess: 
    print(tf.trainable_variables()) 
    sess.run(init) 
+0

儘量避免一次詢問多個問題。如果你找不到任何答案,並且你認爲它們都是有價值的問題,請隨時在單獨的帖子中詢問他們。 – Giewev

回答

4

pretrained VGG-16 model編碼所有模型參數作爲tf.constant() OPS的。 (例如,請參閱tf.constant()here的調用。)因此,模型參數不會出現在tf.trainable_variables()中,並且在沒有大量手術的情況下模型不可變:您需要用以tf.Variable開頭的對象替換常量節點相同的價值以繼續訓練。

通常,在導入再訓練圖時,應該使用tf.train.import_meta_graph()函數,因爲此函數會加載其他元數據(包括變量集合)。 tf.import_graph_def()函數是較低級別的,並且不會填充這些集合。