2017-05-10 82 views
1

我在恢復保存的模型時遇到困難。我培養了CNN的MNIST數據集,所有根據MNIST教程Deep MNIST for Experts,我救了我的模型下面的代碼:如何在TensorFlow中導入模型

saver.save(sess, './Tensorflow_MNIST', global_step=max_steps) 

這將創建下列文件:

  • Tensorflow_MNIST- 1000.data 00000-的-00001
  • Tensorflow_MNIST-1000.index
  • Tensorflow_MNIST-1000.meta
  • 關卡

後來我想加載模型,並繼續訓練:

with tf.Session() as sess: 
new_saver = tf.train.import_meta_graph('./Tensorflow_MNIST-1000.meta') 
new_saver.restore(sess, './Tensorflow_MNIST-1000') 

batch_xs, batch_ys = mnist.train.next_batch(50) 
sess.run(train_step, feed_dict[x: batch_xs, y_batch_ys, keep_prob:0.5]) 

然而,這會返回一個錯誤:

NameError: name 'train_step' is not defined 

所以看起來像圖和它的變量和操作未正確加載。我在這裏做錯了什麼?

回答

1

當使用saver.save() TensorFlow保存由Tensors(即TensorFlow的對象)組成的計算圖。

它不會保存您使用的變量。特別是,不會保存所有而不是 a tf.Tensor

您可能希望擁有自己的數據結構來保存其他任何信息。

您可以使用JSON格式以方便使用,甚至可以使用pickle,這在Python中非常簡單,但不能手工編輯。

希望它有助於

0

時節省:

saver = tf.train.Saver(...variables...) 
# Remember the training_op we want to run by adding it to a collection. 
tf.add_to_collection('train_step', train_step) 

恢復時:

with tf.Session() as sess: 
    .... 

    # tf.get_collection() returns a list. get the first one 
    train_step = tf.get_collection('train_step')[0] 
    sess.run(train_step, ....) 

,如果你想重新使用該模型,我想改變sess.run(train_step...)train_step(...)應該工作

0

調用所有帶「」和「0」的張量如import meta_graph所述那樣添加似乎可以做到。因此,例如,計算精度的呼叫變爲:

test_accuracy = sess.run("accuracy:0", feed_dict={"x:0": mnist.test.images, "y_:0": mnist.test.labels, "keep_prob:0": 1.0})