2016-03-03 48 views
0

我正在訓練一個簡單的網絡來預測單個對象的邊界框座標。但是有圖片沒有找到任何物體。由於網絡總是進行預測,因此它也預測0和1之間的置信度值,這應該表示圖片中存在對象的概率。我的張量與預測被稱爲logits,它是一個(batch_size, 5)張量(置信度,x,y,寬度和高度)。同樣,labels張量也是(batch_size, 5)Tensorflow如何檢查張量行是否僅爲零?

以前我是訓練只與始終有一個物體的圖像,這樣我就可以基本上做到

loss = tf.l2_loss(logits - labels)

我也想開始訓練,沒有對象和圖片時沒有對象的圖片,我不希望網絡因其預測的座標而受到懲罰。在這種情況下,所有重要的是置信度值,它應該接近0(沒有對象)。

我應該如何構建我的標籤和損失函數來完成此操作?我可以將沒有對象的圖像標籤設置爲全零,但是如何檢查特定行只有零?在這種情況下,logits中的對應行也需要設置爲零(置信度值!除外),以便由於座標而產生的損失也爲零。

回答

1

的一種方法找到,如果張量的一排都是零:

import tensorflow as tf 

image = tf.fill([8,8], 0) 
sess = tf.Session() 
sess.run(tf.initialize_all_variables()) 
image_row = tf.slice(image, [1,0], [1, -1]) 
total = tf.reduce_sum(tf.abs(image_row)) 
is_all_zero = tf.equal(total, 0) 
print sess.run([total, is_all_zero, image_row]) 

結果:

[0, True, array([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)] 
+0

這解決了OP的問題,但實際上並沒有檢查該行是全零,但總和爲零。一行可能包含'[100,-100,0,0,0]',這將加和爲零。如果你改變了一些東西,那麼它可以在其他情況下一致地工作,並更精確地匹配OP問題的標題。 –

+0

先取絕對值 – zenna

+0

謝謝,增加了abs運算符來修復此案例。 – fabrizioM