2017-09-17 60 views
0

我在tf中編寫此操作時遇到了一些問題。下面是一個例子,假設我有一個[n,2]佔位符x和一個[n,1]佔位符y。 x = [[1,2],[3,4],[5,6]] y = [1,0,1] 對於y中的每個元素i我想從ith 2d中獲取相應的元素張量。 在這個例子中,輸出應該是[2,3,6]。我嘗試了幾種技術,但沒有成功。是否有一種簡單的方法來做到這一點tensorflow?Tensorflow根據另一個佔位符從佔位符中獲取元素

謝謝

回答

0

二者必選其一tf.gather_ndtf.stacktf.where手動攻擊它:

import tensorflow as tf 

x = tf.convert_to_tensor([[1, 2], [3, 4], [5, 6]]) 
y = tf.convert_to_tensor([1, 0, 1]) 

with tf.Session() as sess: 
    xx = tf.unstack(x, axis=1) 
    ans = tf.where(tf.equal(y, tf.zeros_like(y)), xx[0], xx[1]) 
    print sess.run(ans) 


with tf.Session() as sess: 
    idx = tf.range(0, limit=3, delta=1, name='arange') 
    idx = tf.stack([idx, y], axis=-1) 
    ans = tf.gather_nd(x, idx) 
    print sess.run(ans)