您需要爲您的slice
張量添加一個尺寸,您可以使用tf.pack
來做尺寸,然後我們可以使用tf.gather_nd
沒問題。
import tensorflow as tf
tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
old_slice = tf.constant([[0, 0, 1], [0, 1, 2]])
# We need to add a dimension - we need a tensor of rank 2, 3, 2 instead of 2, 3
dims = tf.constant([[0, 0, 0], [1, 1, 1]])
new_slice = tf.pack([dims, old_slice], 2)
out = tf.gather_nd(tensor, new_slice)
如果我們運行如下代碼:
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
run_tensor, run_slice, run_out = sess.run([tensor, new_slice, out])
print 'Input tensor:'
print run_tensor
print 'Correct param for gather_nd:'
print run_slice
print 'Output:'
print run_out
這應該給正確的輸出:
Input tensor:
[[1 2 3]
[4 5 6]]
Correct param for gather_nd:
[[[0 0]
[0 0]
[0 1]]
[[1 0]
[1 1]
[1 2]]]
Output:
[[1 1 2]
[4 5 6]]