2017-06-12 23 views
1

我是tensorflow的新手。我已經爲mnist圖像分類構建了一個convonet,我正在使用隊列從磁盤批量中讀取圖像(png)並將它傳遞給訓練op(我現在對此很舒服)直到火車和我正在評估我在訓練期間的某些步驟的準確性。如何使用隊列方式(不含feed_dict)#tensorflow在保存的模型上使用測試數據?

我使用Saver對象保存模型,並可以看到正在磁盤上寫入的元和檢查點文件。

現在真正的挑戰是恢復模型一旦完成培訓,並用它來對新的圖像預測

一個在我圖的第一步(訓練)是像下面這需要x_image(圖片從列車隊列中)h_conv1 = tf.nn.relu(conv2d(x_image,W_conv1)+ b_conv1)

因爲我沒有使用Feed字典方法,所以我不能只使用保存程序恢復準確性並傳遞新數據。我必須爲測試數據定義隊列並重新構建圖形(與之前完全一樣),並將引用x_image更改爲指向測試數據隊列。

如何在訓練時恢復學習權重並將其用於此新圖形以簡單運行我的預測/準確性操作。

我試圖按照 - https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10.py教程,但與eval代碼丟失。

此外,如果我在訓練圖中添加虛擬常量,然後嘗試檢索它的值,我可以檢索它。

任何1請幫助。謝謝

+0

我能夠使用saver.restore(),並恢復圖的變量。 小心我沒有運行tf.global_variables_initializer(),以便變量/權重不會重新初始化,而是從保存的模型中恢復。 我現在觀察到的唯一奇怪的事情是,我的預測操作爲相同的輸入圖像返回不同的標籤。我正在使用tf.train.shuffle_batch()來生成測試樣本。任何人都可以指出我的錯誤。 謝謝 – tejas

回答

2

好的,所以我找到了答案。

最初的挑戰是在訓練和驗證階段使用隊列時在列車和測試數據之間切換。 現在,隊列是圖結構的一部分,我們不能簡單地修改它們。

我發現一篇文章使用tf.case在列車和測試隊列之間進行切換,但我無法使用隨機的隨機播放批處理。

手頭真正的任務是保存模型崗位培訓並使用保存的模型預測生產情況。

因此,這裏的流量:

培訓

  • 創建一個用於創建您的圖表的方法(將圖像張量 輸入)。
  • 通過傳遞訓練圖像批次來構建訓練圖形
  • 執行訓練並使用保存對象保存模型。

評價

  • 立即重建與測試圖像批次相同的曲線圖。
  • 在會話中使用的保護對象,以恢復權重(注意你不需要通過恢復哪些變量,默認情況下它會還原所有恢復能變量) 不要在這個時候運行gloabl變量初始化
  • 運行您預測運算(距離新建成的圖表生成)

還要確保您關閉在EVAL下拉輸出功能,因爲它會保持改變輸出對於相同的輸入

下面是僞

train_op, y_predict, accuracy = create_graph(train_input, train_label) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    model_saver = tf.train.Saver() 

    for i in range(2000): 
     if i%100 == 0: 
      train_accuracy = sess.run(accuracy) 
      print("step %d, training accuracy %f" %(i, train_accuracy)) 
     sess.run(train_op) 

    print(sess.run(accuracy)) 
    model_saver.save(sess, 'model/simple_model', global_step=100) 
    coord.request_stop() 
    coord.join(threads) 

對於評價

_, y_predict, accuracy = create_graph(test_input, test_label) 

saver = tf.train.Saver() 

with tf.Session() as sess: 
    saver.restore(sess, tf.train.latest_checkpoint("./model/")) 

    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    label_predict = sess.run([y_predict]) 
相關問題