2017-02-12 143 views
3

我正在尋找一種類似於Python的list.index()函數的TensorFlow方法。如何查找TensorFlow中第一個匹配元素的索引

給出一個矩陣和一個值來查找,我想知道在矩陣的每一行中值的第一次出現。

例如,

m is a <batch_size, 100> matrix of integers 
val = 23 

result = [0] * batch_size 
for i, row_elems in enumerate(m): 
    result[i] = row_elems.index(val) 

我不能假設「VAL」只出現在每行中一次,否則我會使用tf.argmax(米== VAL)已經實現它。在我的情況下,重要的是要獲得第一個發生'val'的索引,而不是任何。

回答

6

看起來tf.argmax的工作原理類似np.argmax(根據the test ),當有多個最大值出現時,它將返回第一個索引。 你可以使用tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1)來得到你想要的。但是,目前tf.argmax的行爲在最大值出現多次的情況下未定義。

如果您擔心未定義的行爲,您可以按照@Igor Tsvetkov的建議在tf.where的返回值上應用tf.argmin。 例如,

# test with tensorflow r1.0 
import tensorflow as tf 

val = 3 
m = tf.placeholder(tf.int32) 
m_feed = [[0 , 0, val, 0, val], 
      [val, 0, val, val, 0], 
      [0 , val, 0, 0, 0]] 

tmp_indices = tf.where(tf.equal(m, val)) 
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0]) 

with tf.Session() as sess: 
    print(sess.run(result, feed_dict={m: m_feed})) # [2, 0, 1] 

注意tf.segment_min將提高InvalidArgumentError時,有不包含val一些列。 row_elems.index(val)row_elems不包含val時也會引發異常。

+0

這非常有幫助!如果我們想將val更新爲new_val,該怎麼辦?我在這裏問這個問題:https://stackoverflow.com/questions/45684445/tensorflow-update-first-matching-element-in-each-row – reese0106

1

看起來有點醜,但工程(假設mval都是張量):

idx = list() 
for t in tf.unpack(m, axis=0): 
    idx.append(tf.reduce_min(tf.where(tf.equal(t, val)))) 
idx = tf.pack(idx, axis=0) 

編輯: 作爲Yaroslav Bulatov提到的,你可以實現與tf.map_fn相同的結果:

def index1d(t): 
    return tf.reduce_min(tf.where(tf.equal(t, val))) 

idx = tf.map_fn(index1d, m, dtype=tf.int64) 
+1

'map_fn'無需解包就可以做到這一點 –

相關問題