2017-07-06 125 views
0

我想檢查批次的偶數和奇數元素,並在需要時進行交換。我設法引起兩個張量我想交織:Tensorflow:根據偶數和奇數索引合併兩個二維張量

def tf_oplu(x, name=None): 


    even = x[:,::2] #slicing into odd and even parts on the batch 
    odd = x[:,1::2] 

    even_flatten = tf.reshape(even, [-1]) # flatten tensors 
    #in row-major order to apply function across them 

    odd_flatten = tf.reshape(odd, [-1]) 

    compare = tf.to_float(even_flatten<odd_flatten) 
    compare_not = tf.to_float(even_flatten>=odd_flatten) 

    #def oplu(x,y): # trivial function 
    # if x<y : # (x<y)==1 
    #  return y, x 
    # else: 
    #  return x, y # (x<y)==0 

    even_flatten_new = odd_flatten * compare + even_flatten * compare_not 
    odd_flatten_new = odd_flatten * compare_not + even_flatten * compare 

    # convolute back 

    even_new = tf.reshape(even_flatten_new,[100,128]) 
    odd_new = tf.reshape(odd_flatten_new,[100,128]) 

現在我想回去$ [100256] $張量奇數和偶數的地方填補。在numpy的我當然會做的事:

y = np.empty((even_new.size + odd_newsize,), dtype=even_new.dtype) 
y[:,0::2] = even_new 
y[:,1::2] = odd_new 

return y 

但這樣的事情是不可能的tensoflow,因爲張是不可修改的。我想這是可能的sparse tensortf.gather_nd,但都需要生成索引數組,這對我來說也是不平凡的任務。 還有一點需要注意:我不想通過tf.py_func使用任何python函數,因爲我檢查到它們只在CPU上運行。也許lambda和tf.map_fn可能會有所幫助?謝謝!

回答

1

要垂直交錯兩個矩陣,您不要使用大號的槍,如gathermap_fn。你可以簡單地交錯它們如下:

tf.reshape(
    tf.stack([even_new, odd_new], axis=1), 
    [-1, tf.shape(even_new)[1]]) 

編輯

要水平交錯他們:

tf.reshape(
    tf.concat([even_new[...,tf.newaxis], odd_new[...,tf.newaxis]], axis=-1), 
    [tf.shape(even_new)[0],-1]) 

的想法是使用堆棧來交織它們在內存中。發生堆棧的維度給出了交織的粒度。如果我們在axis=0處堆疊,則交錯發生在每個元素處,混合列。如果我們堆疊在axis=1上,則整個輸入行保持連續,在行之間發生交錯。

+0

非常感謝!但是這個代碼產生張量[200,128]而不是[100,256]。我已將其更改爲 'y = tf.reshape(tf.stack([even_new,odd_new],axis = 0),[tf.shape(even_new)[0], - 1])'所以輸出爲預期。你能否給我一點解釋,爲什麼它確實在需要的地方放置了奇怪的元素? – Slowpoke

+0

到目前爲止,據我瞭解,它將它們垂直堆疊,然後進行重新塑形,將水平放置在彼此之下的元素放置在一起。 – Slowpoke

+0

是的,我遵循你的numpy例子,它也垂直疊加張量(沿着第一維)。你修改以水平堆疊它們是正確的。 – user1735003

1

您可以使用assign分配到切片中。

odd_new = tf.constant([1,3,5]) 
even_new = tf.constant([2,4,6]) 
y=tf.Variable(tf.zeros(6, dtype=tf.int32)) 

sess = tf.InteractiveSession() 
sess.run(tf.global_variables_initializer()) 
y[0::2].assign(odd_new).eval() 
y[1::2].assign(even_new).eval() 
+0

非常感謝!我最初嘗試使用賦值運算符來構建我的網絡,[但面臨一些問題,並且不願意這樣做](https://stackoverflow.com/questions/44737322/)(但是,我認爲代碼是正確的),所以現在我更喜歡使用嵌入張量流函數,所以第一個答案適合我。但是,不過謝謝! – Slowpoke

相關問題