2017-05-30 189 views
0

比方說,我有一個尺寸爲[batch_size, 5, 10]的張量,稱爲my_tensor。 我還有一個尺寸爲[batch_size, 1]的另一個張量,其中包含一個名爲selecter的索引。如何過濾基於帶索引張量的張量流張量?

我想對於過濾my_tensorselecter生產規模[batch_size, 10]新張量,即只選擇珍視selecter包含。基本上,它有點減少中間維度(其大小爲5)。我覺得tf.where是正確的選擇,但不確定。 我真的很感謝你的幫助!

回答

1

解決方法是使用tf.gather_nd

tf.gather_nd(
    my_tensor, 
    tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1)) 

如果你構建selecter是從一開始1-d可以擺脫squeeze的。

+0

這是完美的。非常感謝你! –

+0

你用什麼版本的tensorflow?我有1.3.0和我的tf.gather_nd不接受軸參數。但是,有這個tf.gather。 – omikron

0

替代的解決方案,工作在Tensorflow 1.3:

max_selecter = tf.reduce_max(selecter) + 1 
my_tensor = tf.boolean_mask(
    outputs, 
    tf.logical_xor(
     tf.sequence_mask(my_tensor + 1, max_selecter), 
     tf.sequence_mask(my_tensor, max_selecter) 
    ) 
)