我想檢查批次的偶數和奇數元素,並在需要時進行交換。我設法引起兩個張量我想交織:Tensorflow:根據偶數和奇數索引合併兩個二維張量
def tf_oplu(x, name=None):
even = x[:,::2] #slicing into odd and even parts on the batch
odd = x[:,1::2]
even_flatten = tf.reshape(even, [-1]) # flatten tensors
#in row-major order to apply function across them
odd_flatten = tf.reshape(odd, [-1])
compare = tf.to_float(even_flatten<odd_flatten)
compare_not = tf.to_float(even_flatten>=odd_flatten)
#def oplu(x,y): # trivial function
# if x<y : # (x<y)==1
# return y, x
# else:
# return x, y # (x<y)==0
even_flatten_new = odd_flatten * compare + even_flatten * compare_not
odd_flatten_new = odd_flatten * compare_not + even_flatten * compare
# convolute back
even_new = tf.reshape(even_flatten_new,[100,128])
odd_new = tf.reshape(odd_flatten_new,[100,128])
現在我想回去$ [100256] $張量奇數和偶數的地方填補。在numpy的我當然會做的事:
y = np.empty((even_new.size + odd_newsize,), dtype=even_new.dtype)
y[:,0::2] = even_new
y[:,1::2] = odd_new
return y
但這樣的事情是不可能的tensoflow,因爲張是不可修改的。我想這是可能的sparse tensor或tf.gather_nd,但都需要生成索引數組,這對我來說也是不平凡的任務。 還有一點需要注意:我不想通過tf.py_func
使用任何python函數,因爲我檢查到它們只在CPU上運行。也許lambda和tf.map_fn
可能會有所幫助?謝謝!
非常感謝!但是這個代碼產生張量[200,128]而不是[100,256]。我已將其更改爲 'y = tf.reshape(tf.stack([even_new,odd_new],axis = 0),[tf.shape(even_new)[0], - 1])'所以輸出爲預期。你能否給我一點解釋,爲什麼它確實在需要的地方放置了奇怪的元素? – Slowpoke
到目前爲止,據我瞭解,它將它們垂直堆疊,然後進行重新塑形,將水平放置在彼此之下的元素放置在一起。 – Slowpoke
是的,我遵循你的numpy例子,它也垂直疊加張量(沿着第一維)。你修改以水平堆疊它們是正確的。 – user1735003