2016-08-15 119 views
0

我有一個Theano tensor3(即,3維陣列)x移基於偏移矢量tensor3元素的位置

[[[ 0 1 2 3] 
    [ 4 5 6 7] 
    [ 8 9 10 11]] 

[[12 13 14 15] 
    [16 17 18 19] 
    [20 21 22 23]]] 

以及一個Theano向量(即,一維陣列)y,我們將參考作爲「偏移」向量,因爲它指定了期望的偏移:

[2, 1] 

欲轉移基於矢量yx元素的位置,從而使輸出如下(在第二維進行變速):

[[[ a b c d] 
    [ e f g h] 
    [ 0 1 2 3]] 

[[ i j k l] 
    [12 13 14 15] 
    [16 17 18 19]]] 

其中ab,...,l可以是任何數目。

例如,一個有效的輸出可能是:

[[[ 0 0 0 0] 
    [ 0 0 0 0] 
    [ 0 1 2 3]] 

[[ 0 0 0 0] 
    [12 13 14 15] 
    [16 17 18 19]]] 

另一種有效的輸出可能是:

[[[ 4 5 6 7] 
    [ 8 9 10 11] 
    [ 0 1 2 3]] 

[[20 21 22 23] 
    [12 13 14 15] 
    [16 17 18 19]]] 

我知道的功能theano.tensor.roll(x, shift, axis=None)的,但是shift只能拿一個標量作爲輸入,即它移動所有具有相同偏移量的元素。

例如,代碼:

import theano.tensor 
from theano import shared 
import numpy as np 

x = shared(np.arange(24).reshape((2,3,4))) 
print('theano.tensor.roll(x, 2, axis=1).eval(): \n{0}'. 
     format(theano.tensor.roll(x, 2, axis=1).eval())) 

輸出:

theano.tensor.roll(x, 2, axis=1).eval(): 
[[[ 4 5 6 7] 
    [ 8 9 10 11] 
    [ 0 1 2 3]] 

[[16 17 18 19] 
    [20 21 22 23] 
    [12 13 14 15]]] 

這不是我想要的。

如何根據偏移矢量移動tensor3元素的位置? (注意在這個例子中提供的代碼中,爲了方便起見,tensor3是一個共享變量,但在我的實際代碼中它將是一個符號變量)

回答

0

我找不到任何專用的功能,所以我簡單地結束了使用theano.scan

import theano 
import theano.tensor 

from theano import shared 
import numpy as np 

y = shared(np.array([2,1])) 
x = shared(np.arange(24).reshape((2,3,4))) 
print('x.eval():\n{0}\n'.format(x.eval())) 

def shift_and_reverse_row(matrix, y):  
    ''' 
    Shift and reverse the matrix in the direction of the first dimension (i.e., rows) 
    matrix: matrix 
    y: scalar 
    ''' 
    new_matrix = theano.tensor.zeros_like(matrix) 
    new_matrix = theano.tensor.set_subtensor(new_matrix[:y,:], matrix[y-1::-1,:]) 
    return new_matrix 

new_x, updates = theano.scan(shift_and_reverse_row, outputs_info=None, 
          sequences=[x, y[::-1]]) 
new_x = new_x[:, ::-1, :] 
print('new_x.eval(): \n{0}'.format(new_x.eval())) 

輸出:

x.eval(): 
[[[ 0 1 2 3] 
    [ 4 5 6 7] 
    [ 8 9 10 11]] 

[[12 13 14 15] 
    [16 17 18 19] 
    [20 21 22 23]]] 

new_x.eval(): 
[[[ 0 0 0 0] 
    [ 0 0 0 0] 
    [ 0 1 2 3]] 

[[ 0 0 0 0] 
    [12 13 14 15] 
    [16 17 18 19]]]