2016-06-22 66 views
0

應用功能我有兩個問題:張量切片並通過張量

  1. 讓張T具有形狀[n1, n2, n3, n4]。讓另一個張量IDX的形狀[n1, n2]類型int包含所需的指標。如何獲得形狀[n1, n2, n4]的張量,其中我只想從n3 dim的T中提取指數,這些指數在IDX中指定。簡單的例子:

    x = [[[2, 3, 1, 2, 5], 
        [7, 1, 5, 6, 0], 
        [7, 8, 1, 3, 8]], 
        [[0, 7, 7, 6, 9], 
        [5, 6, 7, 8, 8], 
        [2, 3, 2, 9, 6]]] 
    idx = [[1, 0, 2], 
        [4, 3, 3]] 
    res = [[3, 7, 1], 
        [9, 8, 9]]` 
    
  2. 給定一個函數,它接受1D張量FUNC(X,Y)如何可以將它應用到4D張量X,Y在過去的尺寸,即結果 - 對於所有i,j,k,結果[i,j,k] = f(X [i,j,k,:,Y [i,j,k,:])的3D張量。我發現tf.py_func,但無法得到如何使用它在我的情況。

在此先感謝您的幫助!

回答

1

我用tf.gather_nd工作了問題1。

輸入是:

  • x:你的張量T從中[n1, n2, n3, n4]
    • 我用更清晰值從0到size(T)
  • idx以提取數值,形狀:在您想要從T中提取的索引,形狀爲[n1, n2]和含值從0n3 - 1

結果是:

  • res:的Tidx每個指數之所提取的值,形狀[n1, n2, n4]

的作爲tf.gather_nd()期望你可以創建整個指數來檢索x(例如[1, 0, 4, 1]),我們必須先在indices_base中創建它。

indices需要是形狀res + R[n1, n2, n4, R]其中R=4是張量x的秩的論點。

# Inputs: 
n1 = 2 
n2 = 3 
n3 = 5 
n4 = 2 
x = tf.reshape(tf.range(n1*n2*n3*n4), [n1, n2, n3, n4]) # range(60) reshaped 
idx = tf.constant([[1, 0, 2], [4, 3, 3]]) # shape [n1, n2] 

range_n1 = tf.reshape(tf.range(n1), [n1, 1, 1, 1]) 
indices_base_1 = tf.tile(range_n1, [1, n2, n4, 1]) 

range_n2 = tf.reshape(tf.range(n2), [1, n2, 1, 1]) 
indices_base_2 = tf.tile(range_n2, [n1, 1, n4, 1]) 

range_n4 = tf.reshape(tf.range(n4), [1, 1, n4, 1]) 
indices_base_4 = tf.tile(range_n4, [n1, n2, 1, 1]) 

idx = tf.reshape(idx, [n1, n2, 1, 1]) 
idx = tf.tile(idx, [1, 1, n4, 1]) 

# Create the big indices needed of shape [n1, n2, n3, n4] 
indices = tf.concat(3, [indices_base_1, indices_base_2, idx, indices_base_4]) 

# Finally we can apply tf.gather_nd 
res = tf.gather_nd(x, indices) 

無論如何,這是相當複雜的,我不知道這是否能產生良好的性能。

P.S:您應該在單獨的帖子中發佈問題2。