2015-06-02 17 views
4

我有一個名爲A的Nx2x2x2數組。我也有一個名爲B的Nx2數組,它告訴我A中最後兩個維度的位置,我感興趣。我正在通過使用循環(如下面代碼中的C)或使用列表理解(如下面的代碼中的D)來獲取Nx2數組。我想知道矢量化是否會帶來時間收益,如果是的話,如何矢量化這個任務。Numpy:多維索引。一行一行沒有循環

我目前的矢量化方法(下面的代碼中的E)是使用B來索引A的每個子矩陣,所以它不會返回我想要的。我想e。返回相同的C或D.

輸入:

A=np.reshape(np.arange(0,32),(4,2,2,2)) 
print("A") 
print(A) 
B=np.concatenate((np.array([0,1,0,1])[:,np.newaxis],np.array([1,1,0,0])[:,np.newaxis]),axis=1) 
print("B") 
print(B) 
C=np.empty(shape=(4,2)) 
for n in range(0, 4): 
    C[n,:]=A[n,:,B[n,0],B[n,1]] 
print("C") 
print(C) 
D = np.array([A[n,:,B[n,0],B[n,1]] for n in range(0, 4)]) 
print("D") 
print(D) 
E=A[:,:,B[:,0],B[:,1]] 
print("E") 
print(E) 

輸出:

A 
[[[[ 0 1] 
    [ 2 3]] 

    [[ 4 5] 
    [ 6 7]]] 


[[[ 8 9] 
    [10 11]] 

    [[12 13] 
    [14 15]]] 


[[[16 17] 
    [18 19]] 

    [[20 21] 
    [22 23]]] 


[[[24 25] 
    [26 27]] 

    [[28 29] 
    [30 31]]]] 
B 
[[0 1] 
[1 1] 
[0 0] 
[1 0]] 
C 
[[ 1. 5.] 
[ 11. 15.] 
[ 16. 20.] 
[ 26. 30.]] 
D 
[[ 1 5] 
[11 15] 
[16 20] 
[26 30]] 
E 
[[[ 1 3 0 2] 
    [ 5 7 4 6]] 

[[ 9 11 8 10] 
    [13 15 12 14]] 

[[17 19 16 18] 
    [21 23 20 22]] 

[[25 27 24 26] 
    [29 31 28 30]]] 

回答

3

複雜切片操作可以以量化的方式來完成,像這樣 -

shp = A.shape 
out = A.reshape(shp[0],shp[1],-1)[np.arange(shp[0]),:,B[:,0]*shp[3] + B[:,1]] 

您正在使用第一列和第二列B來索引輸入4D陣列的第三和第四維,A。這意味着,基本上你正在切割4D陣列,最後兩個維度是融合在一起。因此,您需要使用B獲得具有該融合格式的線性索引。當然,在完成所有這些之前,您需要將A重塑爲具有A.reshape(shp[0],shp[1],-1)的3D陣列。

驗證結果的通用4D陣列情況下 -

In [104]: A = np.random.rand(6,3,4,5) 
    ...: B = np.concatenate((np.random.randint(0,4,(6,1)),np.random.randint(0,5,(6,1))),1) 
    ...: 

In [105]: C=np.empty(shape=(6,3)) 
    ...: for n in range(0, 6): 
    ...:  C[n,:]=A[n,:,B[n,0],B[n,1]] 
    ...:  

In [106]: shp = A.shape 
    ...: out = A.reshape(shp[0],shp[1],-1)[np.arange(shp[0]),:,B[:,0]*shp[3] + B[:,1]] 
    ...: 

In [107]: np.allclose(C,out) 
Out[107]: True