2017-05-06 64 views
0

我正在使用tensorflow來訓練和使用一個小型神經網絡(2d分類有兩個類),但我有一個非常奇怪的問題,並且看不到我做錯了什麼: 當我繪製the predictions vs the true labels只有測試批次的準確度爲1,我顯然有一些錯誤分類的樣本。在我看來,tf.argmax是錯誤地評估爲1的準確性的責任,但顯然這不是真正的原因。 不管怎麼說,我得出這樣的結論,通過計算從最後一層輸出精度:Tensorflow argmax給出錯誤指示

with tf.name_scope('accuracy'): 
    plabel = tf.argmax(y, 1) # vector of predicted label, elem {0,1}^batch_size 
    tlabel = tf.argmax(y_, 1) # similar vector of true labels 
    correct_predictions = tf.cast(tf.equal(plabel, tlabel), tf.float32) 
... 
with tf.Session() as sess: 
    batchx, batchy = generate_batch() 
    predictions, acc = sess.run([y, accuracy], feed_dict={x: batchx, y_: batchy}) 
    mistakes = 0. 
    for j in range(batch_size): 
     if (predictions[j, 0] - predictions[j, 1])*(batchy[j, 0] - batchy[j, 1]) < 0: 
      print("mistake: ", predictions[j], batchy[j]) 
      mistakes += 1./batch_size 
    print("Acc = {}/{} = 1-m/b".format(acc, 1. - mistakes)) 

x和Y_是輸入張量,y是最後一層和模型已經被訓練。

它給了我下面的輸出:

Acc = 1.0/0.86 = 1-m/b 

這些值應該是相同的。

該圖還表明真正精度不1,或所評估的準確性張量不屬於相同的運行作爲預測(Y)。

我沒有發現任何暗示tf.argmax真的是問題,感到很絕望。所以,在此先感謝您的幫助

回答

0

看來你的準確度功能是正確的,因爲你沒有張貼整個你的代碼,我建議你計算會話內的精度,所以每次可以打印預測和真實的內容標籤並追蹤執行情況。

得到像predictions = sess.run(y, feed_dict={x: batchx, y_: batchy})預測然後通過print predictions.eval()打印預言,然後用plabel = tf.argmax(predictions, 1),再次print plabel.eval(),也print batchyprint (tf.argmax(batchy)).eval()

現在你應該知道發生了什麼,每個print聲明走錯了。

+0

我修改這樣的代碼,並刪除所有調試輸出,僅打印你所期望之後,這仍然是: 'ACC,預測,校正,p,T = sess.run([精度,Y,correct_predictions,PLABEL,的TLabel],feed_dict = F)' - '在範圍Ĵ(的batch_size):打印(「這應該是1: 」,FDICT [Y _] [J,T [j]的])' - 但有時它打印0 – nielsrolf

+0

好這只是一個錯字,在我的原代碼中,我使用的錯誤的飼料字典。對不起,謝謝 – nielsrolf

0

對於精度:

def evaluate(X_data, y_data): 
    num_examples = len(X_data) 
    total_accuracy = 0 
    sess = tf.get_default_session() 
    for offset in range(0, num_examples, BATCH_SIZE): 
     batch_x, batch_y = X_data[offset:offset+BATCH_SIZE], y_data[offset:offset+BATCH_SIZE] 
     accuracy = sess.run(accuracy_operation, feed_dict={x: batch_x, y: batch_y}) 
     total_accuracy += (accuracy * len(batch_x)) 
    return total_accuracy/num_examples 

用於獲取預測:

prediction = sess.run(tf.argmax(logits, 1), feed_dict={x: train_data})