2017-08-27 55 views
1

好吧,這裏給出的數據; 形狀有三個numpy陣列: (i,4,2),(i,4,3),(i,4,2) i在它們之間共享,但是是可變的。 dtype是float32的一切。 目標是按特定順序交織它們。讓我們看一下數據索引0對於這些陣列:Numpy交織異形陣列

[[-208. -16.] 
[-192. -16.] 
[-192. 0.] 
[-208. 0.]] 

[[ 1. 1. 1.] 
[ 1. 1. 1.] 
[ 1. 1. 1.] 
[ 1. 1. 1.]] 

[[ 0.49609375 0.984375 ] 
[ 0.25390625 0.984375 ] 
[ 0.25390625 0.015625 ] 
[ 0.49609375 0.015625 ]] 

在這種情況下,concatened目標陣列會是這個樣子:

[-208, -16, 1, 1, 1, 0.496, 0.984, -192, -16, 1, 1, 1, ...] 

然後繼續與索引1

我不知道如何實現這一點,因爲連接函數只是告訴我形狀不匹配。目標數組的形狀並不重要,只是它的內存視圖必須按照給定順序才能上傳到GPU着色器。

編輯:我可以用幾個python for循環來實現,但是性能影響會成爲這個程序中的一個問題。

回答

2

使用np.dstack和扁平化與np.ravel() -

np.dstack((a,b,c)).ravel() 

現在,np.dstack基本上沿着第三軸堆疊。所以,我們也可以使用np.concatenate沿軸過,像這樣 -

np.concatenate((a,b,c),axis=2).ravel() 

採樣運行 -

1)設置輸入數組:

In [613]: np.random.seed(1234) 
    ...: n = 3 
    ...: m = 2 
    ...: a = np.random.randint(0,9,(n,m,2)) 
    ...: b = np.random.randint(11,99,(n,m,2)) 
    ...: c = np.random.randint(101,999,(n,m,2)) 
    ...: 

2)檢查輸入值:

In [614]: a 
Out[614]: 
array([[[3, 6], 
     [5, 4]], 

     [[8, 1], 
     [7, 6]], 

     [[8, 0], 
     [5, 0]]]) 

In [615]: b 
Out[615]: 
array([[[84, 58], 
     [61, 87]], 

     [[48, 45], 
     [49, 78]], 

     [[22, 11], 
     [86, 91]]]) 

In [616]: c 
Out[616]: 
array([[[104, 359], 
     [376, 560]], 

     [[472, 720], 
     [566, 115]], 

     [[344, 556], 
     [929, 591]]]) 

3)輸出:

In [617]: np.dstack((a,b,c)).ravel() 
Out[617]: 
array([ 3, 6, 84, 58, 104, 359, 5, 4, 61, 87, 376, 560, 8, 
     1, 48, 45, 472, 720, 7, 6, 49, 78, 566, 115, 8, 0, 
     22, 11, 344, 556, 5, 0, 86, 91, 929, 591]) 
+0

這個工作。你的示例形狀稍微偏離了一點,但它確實有效。 謝謝! – Berserker

+0

@Berserker Yeah'm = 4'對於你的情況,但是這樣會產生巨大的數組,這對於樣本運行來說太多的數據,所以縮短了它:) – Divakar

+0

我更多地提到「b」的形狀在第三軸不是3。 – Berserker

0

我會做的是:

np.hstack([a, b, c]).flatten()

假設,b,c是三個數組