2017-08-05 51 views
0

掩蓋argmax每一行我有形狀NX7的張量,這看起來是這樣的:獲得不同的面具TensorFlow

[0.97863993 0.64479575 -0.202357 0.94678476 0.0080051 0.44507797 0.47864 
0.05914348 -0.72649432 0.193803 0.47295245 0.8381458 0.30449861 0.46783] 

我也有同樣的形狀,這是一個布爾值面具的另一張量:

[True False True True False True False 
False True False False True False False] 

我想在第一張每行的argmax,但僅限於這面具是真,以下數組所以基本上argmax這些元素:

[0.97863993 X   -0.202357 0.94678476 X   0.44507797 X 
X   -0.72649432 X   X   0.8381458 X   X] 

哪些應該因此而成爲:

[0 
4] 

這是可能的TensorFlow?我試圖找出與tf.boolean_mask,但我不知道如何處理掩碼中具有不同數量的True值的不同行。

輸入代碼TF:

mask = tf.placeholder(shape=[None, 7], dtype=tf.bool) 
val = tf.placeholder(shape=[None, 7], dtype=tf.float32) 

arg_max = ??? 

注意,我想負值正確處理,以及(另有Ishant Mrinal提出的方法將工作)。

回答

0

轉換布爾數組轉換成一個int數組

# mask = tf.placeholder(shape=[None, 7], dtype=tf.bool) 
# mask = tf.cast(mask, dtype=tf.float32) 
mask = tf.placeholder(shape=[None, 7], dtype=tf.float32) 
val = tf.placeholder(shape=[None, 7], dtype=tf.float32) 
argmax = tf.argmax(tf.multiply(val, mask), axis=1) 
sess.run(argmax, {val: your_val_array, mask: 2*mask_bool_array.astype(float)-1 }) 
+0

嗨Ishant,感謝您的建議。如果val vector中不存在唯一的負面情況,這將起作用,但我不能認爲在我的情況下不幸。 –

+0

我認爲你可以使它工作,在掩碼中使用-1作爲假值。 –

+0

我編輯了使用val的負值的答案 –

0

你可以嘗試像

arg_max = tf.argmax(tf.minimum(val, (2* tf.to_float(mask) - 1) * np.inf), axis=1) 
相關問題