2017-09-28 56 views
0

此問題與this one相同,但對於Tensorflow。如何從tensorflow張量中的每一行中選擇不同的列?

假設我有「行」的2D張量,並希望從各行選擇的第i個元素和構成這些元件的結果列中,具有在選擇器張量,如下面

import tensorflow as tf 
import numpy as np 

rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64) 
# gives 
# array([[ 0, 1, 2], 
#  [ 3, 4, 5], 
#  [ 6, 7, 8], 
#  [ 9, 10, 11], 
#  [12, 13, 14], 
#  [15, 16, 17], 
#  [18, 19, 20], 
#  [21, 22, 23], 
#  [24, 25, 26], 
#  [27, 28, 29]]) 


selector = tf.get_variable("selector", [10,1], dtype=tf.int8, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]])) 

result_of_selection = ... 

# should be 
# array([[ 0], 
#  [ 4], 
#  [ 6], 
#  [11], 
#  [13], 
#  [15], 
#  [18], 
#  [23], 
#  [26], 
#  [28]]) 

我會怎樣做這個?

UPDATE

我寫這樣(感謝@Psidom)

import tensorflow as tf 
import numpy as np 

rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64) 

# selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([0, 1, 0, 2, 1, 0, 0, 2, 2, 1], dtype=tf.int32)) 
# selector = tf.expand_dims(selector, axis=1) 
selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]], dtype=tf.int32)) 

ordinals = tf.reshape(tf.range(rows.shape[0]), (-1,1)) 

#idx = tf.concat([selector, ordinals], axis=1) 
idx = tf.stack([selector, ordinals], axis=-1) 

result = tf.gather_nd(rows, idx) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    rows_value, result_value = sess.run([rows, result]) 
    print("rows_value: " + str(rows_value)) 
    print("selector_value: " + str(result_value)) 

,它給了

rows_value: [[ 0. 1. 2.] 
[ 3. 4. 5.] 
[ 6. 7. 8.] 
[ 9. 10. 11.] 
[ 12. 13. 14.] 
[ 15. 16. 17.] 
[ 18. 19. 20.] 
[ 21. 22. 23.] 
[ 24. 25. 26.] 
[ 27. 28. 29.]] 
selector_value: [[ 0.] 
[ 4.] 
[ 2.] 
[ 0.] 
[ 0.] 
[ 0.] 
[ 0.] 
[ 0.] 
[ 0.] 
[ 0.]] 

即不正確。

UPDATE 2

固網

idx = tf.stack([ordinals, selector], axis=-1) 

是不正確的順序。要做到這一點

回答

2

的方法之一是明確通過堆疊可使用tf.rangeselector創建的行索引構建的指數,然後用tf.gather_nd收集物品:

rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64) 
selector = tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]]) 

idx = tf.stack([tf.reshape(tf.range(rows.shape[0]), (-1,1)), selector], axis=-1) 

with tf.Session() as sess: 
    print(sess.run(tf.gather_nd(rows, idx))) 

#[[ 0.] 
# [ 4.] 
# [ 6.] 
# [ 11.] 
# [ 13.] 
# [ 15.] 
# [ 18.] 
# [ 23.] 
# [ 26.] 
# [ 28.]] 

這裏idx是原張量中所有元素的實際指數:

with tf.Session() as sess: 
    print(idx.eval()) 
#[[[0 0]] 

# [[1 1]] 

# [[2 0]] 

# [[3 2]] 

# [[4 1]] 

# [[5 0]] 

# [[6 0]] 

# [[7 2]] 

# [[8 2]] 

# [[9 1]]] 

編輯selector作爲變量:

selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]])) 
idx = tf.stack([tf.reshape(tf.range(rows.shape[0]), (-1,1)), selector], axis=-1) 

with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
    print(sess.run(tf.gather_nd(rows, idx))) 

#[[ 0.] 
# [ 4.] 
# [ 6.] 
# [ 11.] 
# [ 13.] 
# [ 15.] 
# [ 18.] 
# [ 23.] 
# [ 26.] 
# [ 28.]] 
+0

如果'selector'不是恆定的,並且可以沿途改變,將這項工作? – Dims

+0

您是否使用'stack'而不是'concat'來添加額外的單一維度? – Dims

+0

它應該工作得很好,只要記住初始化變量。 – Psidom

相關問題