你想要的就像是一個自定義的減少功能。如果要保持像最大值的指數indices
話,我會建議使用tf.reduce_max
:
max_params = tf.reduce_max(params_tensor, reduction_indices=[2])
否則,這裏是讓你想要什麼(張量對象是不可轉讓的,所以我們創建的二維表的一種方式張量和它使用tf.pack
)包裝:
import tensorflow as tf
import numpy as np
with tf.Graph().as_default():
params_tensor = tf.pack(np.random.randint(1,256, [5,5,10]).astype(np.int32))
indices = tf.pack(np.random.randint(1,10,[5,5]).astype(np.int32))
output = [ [None for j in range(params_tensor.get_shape()[1])] for i in range(params_tensor.get_shape()[0])]
for i in range(params_tensor.get_shape()[0]):
for j in range(params_tensor.get_shape()[1]):
output[i][j] = params_tensor[i,j,indices[i,j]]
output = tf.pack(output)
with tf.Session() as sess:
params_tensor,indices,output = sess.run([params_tensor,indices,output])
print params_tensor
print indices
print output
來源
2017-03-05 00:54:40
Ali
謝謝。這就是我想要的 –
對於我來說,將參數平滑爲(-1,256)和索引爲(-1)似乎更容易些,使用gather_nd然後解開平鋪,而不是平鋪 – Maxim
geez。這真的很聰明,也@ @ @ $。煙霧從我頭上讀出來。它也很有用!謝謝。爲我節省了很多時間。 – thang