2017-07-10 60 views
0

我想對特定維度上的張量進行排序,並返回指定每個元素的排序索引的相同維度的張量。看來,tf.nn.top_k可以返回排序的索引,但如何將其映射回來?排序張量並返回排序後的索引?

input = [[10, 3, 1], [5, 6, 2], [1, 7, 10]] 
_, indices = tf.nn.top_k(input, k=3, sorted=True) 
indices = [[0, 1, 2], [1, 0, 2], [2, 0, 1]] 

我希望得到的是

reordered = [[0, 1, 2], [1, 0, 2], [2, 1, 0]] 
+0

你是否假設'k'實際上是你的數組的寬度? – user1735003

+0

是的,以便我可以得到與輸入相同形狀的張量來排列所有元素。 – Yang

回答

1

假設indices包含所有指數,即是range(k)置換,可以使用一維張量

tf.map_fn(tf.invert_permutation, indices) 

tf.invert_permutation作品,因此您需要將其包裝到一個map_fn中以將其應用於每行indices)。