2017-09-01 117 views
1

它來自Udacity深度學習基礎課程。它似乎爲他們工作。但它在我的電腦中不起作用。請看一看。感謝您的幫助!Tensorflow保存和恢復變量不一樣

講座和我的電腦的tensorflow版本都是1.0.0。

import tensorflow as tf 

# The file path to save the data 
save_file = './model.ckpt' 

# Two Tensor Variables: weights and bias 
weights = tf.Variable(tf.truncated_normal([2, 3])) 
bias = tf.Variable(tf.truncated_normal([3])) 

# Class used to save and/or restore Tensor Variables 
saver = tf.train.Saver() 

with tf.Session() as sess: 
    # Initialize all the Variables 
    sess.run(tf.global_variables_initializer()) 

    # Show the values of weights and bias 
    print('Weights:') 
    print(sess.run(weights)) 
    print('Bias:') 
    print(sess.run(bias)) 

    # Save the model 
    saver.save(sess, save_file) 

# Remove the previous weights and bias 
tf.reset_default_graph() 

# Two Variables: weights and bias 
weights = tf.Variable(tf.truncated_normal([2, 3])) 
bias = tf.Variable(tf.truncated_normal([3])) 

# Class used to save and/or restore Tensor Variables 
saver = tf.train.Saver() 

with tf.Session() as sess: 
    # Load the weights and bias 
    saver.restore(sess, save_file) 

    # Show the values of weights and bias 
    print('Weight:') 
    print(sess.run(weights)) 
    print('Bias:') 
    print(sess.run(bias)) 
+0

你的代碼適用於我的最新TensorFlow。你能否更新到最新的TensorFlow版本並重試?如果它仍然不起作用,會出現什麼問題?你有錯誤還是錯誤的輸出? –

回答

0

輸入張量流後插入tf.reset_default_graph()

0

我在1.1.0中運行了你的代碼,結果是一樣的...