我用tf.gather_nd
工作了問題1。
輸入是:
x
:你的張量T
從中[n1, n2, n3, n4]
idx
以提取數值,形狀:在您想要從T
中提取的索引,形狀爲[n1, n2]
和含值從0
到n3 - 1
結果是:
res
:的T
爲idx
每個指數之所提取的值,形狀[n1, n2, n4]
的作爲tf.gather_nd()
期望你可以創建整個指數來檢索x
(例如[1, 0, 4, 1]
),我們必須先在indices_base
中創建它。
indices
需要是形狀res + R
即[n1, n2, n4, R]
其中R=4
是張量x
的秩的論點。
# Inputs:
n1 = 2
n2 = 3
n3 = 5
n4 = 2
x = tf.reshape(tf.range(n1*n2*n3*n4), [n1, n2, n3, n4]) # range(60) reshaped
idx = tf.constant([[1, 0, 2], [4, 3, 3]]) # shape [n1, n2]
range_n1 = tf.reshape(tf.range(n1), [n1, 1, 1, 1])
indices_base_1 = tf.tile(range_n1, [1, n2, n4, 1])
range_n2 = tf.reshape(tf.range(n2), [1, n2, 1, 1])
indices_base_2 = tf.tile(range_n2, [n1, 1, n4, 1])
range_n4 = tf.reshape(tf.range(n4), [1, 1, n4, 1])
indices_base_4 = tf.tile(range_n4, [n1, n2, 1, 1])
idx = tf.reshape(idx, [n1, n2, 1, 1])
idx = tf.tile(idx, [1, 1, n4, 1])
# Create the big indices needed of shape [n1, n2, n3, n4]
indices = tf.concat(3, [indices_base_1, indices_base_2, idx, indices_base_4])
# Finally we can apply tf.gather_nd
res = tf.gather_nd(x, indices)
無論如何,這是相當複雜的,我不知道這是否能產生良好的性能。
P.S:您應該在單獨的帖子中發佈問題2。