掩蓋argmax每一行我有形狀NX7的張量,這看起來是這樣的:獲得不同的面具TensorFlow
[0.97863993 0.64479575 -0.202357 0.94678476 0.0080051 0.44507797 0.47864
0.05914348 -0.72649432 0.193803 0.47295245 0.8381458 0.30449861 0.46783]
我也有同樣的形狀,這是一個布爾值面具的另一張量:
[True False True True False True False
False True False False True False False]
我想在第一張每行的argmax,但僅限於這面具是真,以下數組所以基本上argmax這些元素:
[0.97863993 X -0.202357 0.94678476 X 0.44507797 X
X -0.72649432 X X 0.8381458 X X]
哪些應該因此而成爲:
[0
4]
這是可能的TensorFlow?我試圖找出與tf.boolean_mask
,但我不知道如何處理掩碼中具有不同數量的True
值的不同行。
輸入代碼TF:
mask = tf.placeholder(shape=[None, 7], dtype=tf.bool)
val = tf.placeholder(shape=[None, 7], dtype=tf.float32)
arg_max = ???
注意,我想負值正確處理,以及(另有Ishant Mrinal提出的方法將工作)。
嗨Ishant,感謝您的建議。如果val vector中不存在唯一的負面情況,這將起作用,但我不能認爲在我的情況下不幸。 –
我認爲你可以使它工作,在掩碼中使用-1作爲假值。 –
我編輯了使用val的負值的答案 –