2017-05-07 95 views
0

有兩個python文件,第一個用於保存張量流模型 。第二個是用於恢復保存的模型。如何恢復已保存的張量流模型?

問:

  1. 當我運行了兩個文件,一個接着一個,它的確定。

  2. 當我運行第一個,重新啓動編輯並運行第二個,它 告訴我,w1沒有定義?

我想要做的是:

  1. 節省tensorflow模型

  2. 恢復保存的模型

這有什麼錯呢?感謝您的幫助?

model_save.py

import tensorflow as tf 
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1') 
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2') 
saver = tf.train.Saver() 

with tf.Session() as sess: 
sess.run(tf.global_variables_initializer()) 
saver.save(sess, 'SR\\my-model') 

model_restore.py

import tensorflow as tf 

with tf.Session() as sess:  
saver = tf.train.import_meta_graph('SR\\my-model.meta') 
saver.restore(sess,'SR\\my-model') 
print (sess.run(w1)) 

enter image description here

回答

1

簡單地說,你應該使用

print (sess.run(tf.get_default_graph().get_tensor_by_name('w1:0'))) 

而不是print (sess.run(w1))在您的model_restore.py文件中。

model_save.py

import tensorflow as tf 
w1_node = tf.Variable(tf.random_normal(shape=[2]), name='w1') 
w2_node = tf.Variable(tf.random_normal(shape=[5]), name='w2') 
saver = tf.train.Saver() 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print(w1_node.eval()) # [ 0.43350926 1.02784836] 
    #print(w1.eval()) # NameError: name 'w1' is not defined 
    saver.save(sess, 'my-model') 

w1_node僅在model_save.py定義,model_restore.py文件不能識別它。 當我們通過其name調用Tensor變量時,我們應該使用get_tensor_by_name,因爲建議這篇文章Tensorflow: How to get a tensor by name?

model_restore.py

import tensorflow as tf 

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('my-model.meta') 
    saver.restore(sess,'my-model') 
    print (sess.run(tf.get_default_graph().get_tensor_by_name('w1:0'))) 
    # [ 0.43350926 1.02784836] 
    print(tf.global_variables()) # print tensor variables 
    # [<tf.Variable 'w1:0' shape=(2,) dtype=float32_ref>, 
    # <tf.Variable 'w2:0' shape=(5,) dtype=float32_ref>] 
    for op in tf.get_default_graph().get_operations(): 
    print str(op.name) # print all the operation nodes' name 
+0

感謝您親切的回答。 根據您的建議和相關參考文獻, 我剛剛解決了這個問題:) –

相關問題