2
變量定義

我tensorflow模型如下:如何保存tensorflow模型(省略標籤張量),沒有定義

X = tf.placeholder(tf.float32, [None,training_set.shape[1]],name = 'X') 
Y = tf.placeholder(tf.float32,[None,training_labels.shape[1]], name = 'Y') 
A1 = tf.contrib.layers.fully_connected(X, num_outputs = 50, activation_fn = tf.nn.relu) 
A1 = tf.nn.dropout(A1, 0.8) 
A2 = tf.contrib.layers.fully_connected(A1, num_outputs = 2, activation_fn = None) 
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = A2, labels = Y))  
global_step = tf.Variable(0, trainable=False) 
start_learning_rate = 0.001 
learning_rate = tf.train.exponential_decay(start_learning_rate, global_step, 200, 0.1, True) 
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) 

現在我要救這個模型遺漏張YY是標籤張量對於培訓,X是實際的輸入)。同時在提及freeze_graph.py時提及輸出節點時,我應該提及"A2"還是以其他名稱保存?

回答

3

雖然您尚未手動定義變量,但上面的代碼片段實際上包含15個可保存的變量。你可以使用這個內部tensorflow功能看到他們:

from tensorflow.python.ops.variables import _all_saveable_objects 
for obj in _all_saveable_objects(): 
    print(obj) 

對於上面的代碼,它產生以下列表:

<tf.Variable 'fully_connected/weights:0' shape=(100, 50) dtype=float32_ref> 
<tf.Variable 'fully_connected/biases:0' shape=(50,) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/weights:0' shape=(50, 2) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/biases:0' shape=(2,) dtype=float32_ref> 
<tf.Variable 'Variable:0' shape=() dtype=int32_ref> 
<tf.Variable 'beta1_power:0' shape=() dtype=float32_ref> 
<tf.Variable 'beta2_power:0' shape=() dtype=float32_ref> 
<tf.Variable 'fully_connected/weights/Adam:0' shape=(100, 50) dtype=float32_ref> 
<tf.Variable 'fully_connected/weights/Adam_1:0' shape=(100, 50) dtype=float32_ref> 
<tf.Variable 'fully_connected/biases/Adam:0' shape=(50,) dtype=float32_ref> 
<tf.Variable 'fully_connected/biases/Adam_1:0' shape=(50,) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/weights/Adam:0' shape=(50, 2) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/weights/Adam_1:0' shape=(50, 2) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/biases/Adam:0' shape=(2,) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/biases/Adam_1:0' shape=(2,) dtype=float32_ref> 

有來自fully_connected層變量和幾個從亞當來優化器(請參閱this question)。請注意,此列表中沒有XY佔位符,因此不需要排除它們。當然,這些張量存在於元圖中,但它們沒有任何價值,因此無法保存。

_all_saveable_objects()列表是tensorflow保存程序默認保存的內容,如果未明確提供變量的話。因此,回答你的主要的問題很簡單:

saver = tf.train.Saver() # all saveable objects! 
with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
    saver.save(sess, "...") 

有沒有辦法提供了tf.contrib.layers.fully_connected函數的名稱(如一個結果,它保存fully_connected_1/...),但我們鼓勵你切換到tf.layers.dense,這有一個name的論點。無論如何,看看爲什麼這是一個好主意,看看thisthis discussion

+0

感謝@Maxim的回覆。真的很感謝你的時間。 –