2016-07-20 107 views
2

我有一個形狀爲[batch_size, D]的二維張量A,以及形狀爲[batch_size]的一維張量BB的每個元素是A的列索引,對於A的每一行,例如。 B[i] in [0,D)Tensorflow索引到具有1d張量的2d張量

什麼是tensorflow得到的值A[B]

例如最好的辦法:

A = tf.constant([[0,1,2], 
       [3,4,5]]) 
B = tf.constant([2,1]) 

與所需的輸出:

some_slice_func(A, B) -> [2,4] 

還有另一種約束。實際上,batch_size實際上是None

在此先感謝!

回答

3

我能得到它的工作使用線性指標:

def vector_slice(A, B): 
    """ Returns values of rows i of A at column B[i] 

    where A is a 2D Tensor with shape [None, D] 
    and B is a 1D Tensor with shape [None] 
    with type int32 elements in [0,D) 

    Example: 
     A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4] 
      [3,4]] 
    """ 
    linear_index = (tf.shape(A)[1] 
        * tf.range(0,tf.shape(A)[0])) 
    linear_A = tf.reshape(A, [-1]) 
    return tf.gather(linear_A, B + linear_index) 

這種感覺稍微哈克雖然。

如果有人知道更好(如更清晰或更快),也請留下一個答案! (我不會接受我自己的一段時間)

0

最簡單的方法可能是通過連接範圍(batch_size)和B來構建適當的2d索引,以獲得batch_size x 2矩陣。然後將其傳遞給tf.gather_nd。

0

代碼什麼@Eugene Brevdo說:

def vector_slice(A, B): 
    """ Returns values of rows i of A at column B[i] 

    where A is a 2D Tensor with shape [None, D] 
    and B is a 1D Tensor with shape [None] 
    with type int32 elements in [0,D) 

    Example: 
     A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4] 
      [3,4]] 
    """ 
    B = tf.expand_dims(B, 1) 
    range = tf.expand_dims(tf.range(tf.shape(B)[0]), 1) 
    ind = tf.concat([range, B], 1) 
    return tf.gather_nd(A, ind)