2017-03-18 41 views
3

注意:tf.image.non_max_suppression不符合我的要求!Tensorflow非最大抑制

我在嘗試執行類似於Canny edge detector的非最大抑制(NMS)。具體來說,2D陣列上的NMS將保持一個值,如果它是一個窗口內的最大值,否則將其抑制(設置爲0)。

例如,考慮矩陣

[[3 2 1 4 2 3] [1 4 2 1 5 2] [2 2 3 2 1 3]]

如果我們考慮的3 x 3的窗口大小,那麼結果應該是

[[0 0 0 0 0 0] [0 4 0 0 5 0] [0 0 0 0 0 0]]

我已搜索周圍,couldn」在tf.imagetf.nn中找不到執行此操作的任何內容。有代碼執行NMS嗎?如果不是,我如何在Tensorflow(Python)中有效地實現NMS?

謝謝!

編輯:我想出了一個辦法來解決這個,但我不知道是否有更好的方法:取最大池1個跨度(即無下采樣)和窗口大小,然後用tf.where檢查該值等於最大池值,如果不是,則設置爲0。有沒有更好的辦法?

回答

2

回答我的問題(雖然開到更好的解決方案):

import tensorflow as tf 
import numpy as np 

def non_max_suppression(input, window_size): 
    # input: B x W x H x C 
    pooled = tf.nn.max_pool(input, ksize=[1, window_size, window_size, 1], strides=[1,1,1,1], padding='SAME') 
    output = tf.where(tf.equal(input, pooled), input, tf.zeros_like(input)) 

    # NOTE: if input has negative values, the suppressed values can be higher than original 
    return output # output: B X W X H x C 

sess = tf.InteractiveSession() 

x = np.array([[3,2,1,4,2,3],[1,4,2,1,5,2],[2,2,3,2,1,3]], dtype=np.float32).reshape([1,3,6,1]) 
inp = tf.Variable(x) 
out = non_max_suppression(inp, 3) 

sess.run(tf.global_variables_initializer()) 
print out.eval().reshape([3,6]) 
''' 
[[ 0. 0. 0. 0. 0. 0.] 
[ 0. 4. 0. 0. 5. 0.] 
[ 0. 0. 0. 0. 0. 0.]] 
''' 

sess.close()