2017-06-02 186 views
0

我知道以下是已經回答的問題,但即使我嘗試了所有建議的解決方案,但都沒有解決我的問題。 我在MNIST數據集上進行了訓練。開始的時候它更深入,但爲了專注於我簡化它的問題。如何在Tensorflow中使用預訓練模型?

mnist = mnist_data.read_data_sets('MNIST_data', one_hot=True) 

# train the net 
def train(): 
    for i in range(1000): 
     batch_xs, batch_ys = mnist.train.next_batch(100) 
     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 
     print("accuracy", sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) 
     if i%100==0: 
      save_path = saver.save(sess, "./tmp/model.ckpt", global_step = i, write_meta_graph=True)  
      print("Model saved in file: %s" % save_path) 

# evaluate the net 
def test(image, label): 
    true_value = tf.argmax(label, 1) 
    prediction = tf.argmax(y, 1) 
    print("true value:", sess.run(true_value)) 
    print("predictions", sess.run(prediction, feed_dict={x:image})) 

sess = tf.InteractiveSession() 

x = tf.placeholder("float", shape=[None, 784]) 
W = tf.Variable(tf.zeros([784,10]), name = "W1") 
b = tf.Variable(tf.zeros([10]), name = "B1") 
y = tf.nn.softmax(tf.matmul(x,W) + b, name ="Y") 
y_ = tf.placeholder("float", shape=[None, 10]) 
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) 
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) 
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) 

saver = tf.train.Saver() 
model_to_restore="./tmp/model.ckpt-100.meta" 
if os.path.isfile(model_to_restore): 
    #what i have to do here?????# 
else: 
#this part works!# 
    print("Model does not exist: training") 
    train() 

謝謝大家的答案!

問候,

西爾維奧

UPDATE

  • 我都嘗試

    saver.restore(sess, model_to_restore) 
    

    saver = tf.train.import_meta_graph(model_to_restore) 
    saver.restore(sess, model_to_restore) 
    

    但在這兩種情況下,我有這個錯誤來自終端:

    DataLossError (see above for traceback): Unable to open table  file ./tmp/model.ckpt.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator? 
    [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]] 
    

回答

0

我覺得你的位置到模型可能是錯誤的,我勸你放棄以下工作流程一試。

由於保存的模型包括幾個文件,我通常將它們保存到一個文件夾訓練後:

modelPath = "myMNIST/model" 
saved_path = saver.save(sess, os.path.join(modelPath, "model.ckpt")) 
print("Model saved in file: ", saved_path) 

這也將告訴你它已經保存的確切位置。

然後我可以開始我的預測的保存位置內(cd到myMNIST)和恢復模式:

ckpt = tf.train.get_checkpoint_state("./model") 
if ckpt and ckpt.model_checkpoint_path: 
    print("Restored Model") 
    saver.restore(sess, ckpt.model_checkpoint_path) 
else: 
    print("Could not restore model!") 
+0

的確!那是錯誤! 非常感謝您的回覆! – SilvioBarra

相關問題