2016-03-07 120 views
8

我有一個關於如何在TensorFlow中建立索引的基本問題。TensorFlow:使用張量來索引另一個張量

在numpy的:

x = np.asarray([1,2,3,3,2,5,6,7,1,3]) 
e = np.asarray([0,1,0,1,1,1,0,1]) 
#numpy 
print x * e[x] 

我能得到

[1 0 3 3 0 5 0 7 1 3] 

我怎樣才能做到這一點TensorFlow?

x = np.asarray([1,2,3,3,2,5,6,7,1,3]) 
e = np.asarray([0,1,0,1,1,1,0,1]) 
x_t = tf.constant(x) 
e_t = tf.constant(e) 
with tf.Session(): 
    ???? 

謝謝!

+0

http://stackoverflow.com/questions/33736795/tensorflow-numpy-like-tensor-indexing?rq=1 是不是你在問什麼? – Alleo

回答

19

幸運的是,你問有關tf.gather()在TensorFlow支持的確切情況:

result = x_t * tf.gather(e_t, x_t) 

with tf.Session() as sess: 
    print sess.run(result) # ==> 'array([1, 0, 3, 3, 0, 5, 0, 7, 1, 3])' 

tf.gather()運算比NumPy's advanced indexing那麼強大:它僅支持其零維提取的張全片。已請求支持更一般的索引,並且正在跟蹤this GitHub issue

+1

非常感謝! – user200340

+3

Tensorflow現在有一個更強大的'tf.gather_nd()'操作。 – fritzo