我有一個形狀爲[batch_size, D]
的二維張量A
,以及形狀爲[batch_size]
的一維張量B
。 B
的每個元素是A
的列索引,對於A
的每一行,例如。 B[i] in [0,D)
。Tensorflow索引到具有1d張量的2d張量
什麼是tensorflow得到的值A[B]
例如最好的辦法:
A = tf.constant([[0,1,2],
[3,4,5]])
B = tf.constant([2,1])
與所需的輸出:
some_slice_func(A, B) -> [2,4]
還有另一種約束。實際上,batch_size
實際上是None
。
在此先感謝!