0
我有一個Theano tensor3
(即,3維陣列)x
:移基於偏移矢量tensor3元素的位置
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
以及一個Theano向量(即,一維陣列)y
,我們將參考作爲「偏移」向量,因爲它指定了期望的偏移:
[2, 1]
欲轉移基於矢量y
的x
元素的位置,從而使輸出如下(在第二維進行變速):
[[[ a b c d]
[ e f g h]
[ 0 1 2 3]]
[[ i j k l]
[12 13 14 15]
[16 17 18 19]]]
其中a
,b
,...,l
可以是任何數目。
例如,一個有效的輸出可能是:
[[[ 0 0 0 0]
[ 0 0 0 0]
[ 0 1 2 3]]
[[ 0 0 0 0]
[12 13 14 15]
[16 17 18 19]]]
另一種有效的輸出可能是:
[[[ 4 5 6 7]
[ 8 9 10 11]
[ 0 1 2 3]]
[[20 21 22 23]
[12 13 14 15]
[16 17 18 19]]]
我知道的功能theano.tensor.roll(x, shift, axis=None)
的,但是shift
只能拿一個標量作爲輸入,即它移動所有具有相同偏移量的元素。
例如,代碼:
import theano.tensor
from theano import shared
import numpy as np
x = shared(np.arange(24).reshape((2,3,4)))
print('theano.tensor.roll(x, 2, axis=1).eval(): \n{0}'.
format(theano.tensor.roll(x, 2, axis=1).eval()))
輸出:
theano.tensor.roll(x, 2, axis=1).eval():
[[[ 4 5 6 7]
[ 8 9 10 11]
[ 0 1 2 3]]
[[16 17 18 19]
[20 21 22 23]
[12 13 14 15]]]
這不是我想要的。
如何根據偏移矢量移動tensor3
元素的位置? (注意在這個例子中提供的代碼中,爲了方便起見,tensor3
是一個共享變量,但在我的實際代碼中它將是一個符號變量)