2016-11-27 88 views
0

我已經在張量流中創建了一個神經網絡。這個網絡是多標籤的。 Ergo:它試圖預測一個輸入集的多個輸出標籤,在這種情況下爲三個。目前我使用此代碼來測試我的網絡是如何準確地預測了三個標籤:測試張量流網絡:in_top_k()替換多標籤分類

_, indices_1 = tf.nn.top_k(prediction, 3) 
_, indices_2 = tf.nn.top_k(item_data, 3) 
correct = tf.equal(indices_1, indices_2) 
accuracy = tf.reduce_mean(tf.cast(correct, 'float')) 
percentage = accuracy.eval({champion_data:input_data, item_data:output_data}) 

該代碼工作正常。問題是現在我試圖創建代碼來測試index_1中找到的前3個項目是否位於indices_2中的前5個圖像之中。我知道tensorflow有一個in_top_k()方法,但據我所知不接受multilabel。目前我一直在嘗試使用一個for循環來對它們進行比較:

_, indices_1 = tf.nn.top_k(prediction, 5) 
_, indices_2 = tf.nn.top_k(item_data, 3) 
indices_1 = tf.unpack(tf.transpose(indices_1, (1, 0))) 
indices_2 = tf.unpack(tf.transpose(indices_2, (1, 0))) 
correct = [] 
for element in indices_1: 
    for element_2 in indices_2: 
     if element == element_2: 
      correct.append(True) 
     else: 
      correct.append(False) 
accuracy = tf.reduce_mean(tf.cast(correct, 'float')) 
percentage = accuracy.eval({champion_data:input_data, item_data:output_data}) 

然而,這是行不通的。代碼運行但我的準確度始終爲0.0。

所以我有兩個問題之一:

1)有一個簡單的替代in_top_k()接受,我可以使用,而不是自定義編寫代碼多標籤分類?

2)如果不是1:我做錯了什麼導致我得到0.0的準確性?

回答

0

當你

correct = tf.equal(indices_1, indices_2) 

要檢查不只是這兩個指標是否包含相同的元素,但它們是否包含在相同的位置相同的元素。這聽起來不像你想要的。

setdiff1d op會告訴你哪些索引在indices_1中,但不在indices_2中,然後可以用它來計算錯誤。

我認爲過於嚴格的正確性檢查可能是什麼導致你得到一個錯誤的結果。

+0

非常感謝!這是朝着正確方向邁出的一大步。它確實需要我更新我的tensorflow,因爲我的版本還沒有setdiff1d。你介意如何計算錯誤嗎?我已經嘗試了一些東西,但似乎無法弄清楚如何知道setdif1d找到的許多差異。 –