2017-01-24 69 views
2

我正在尋找一種更優化的方法來將(n,n)或(n,n,1)矩陣轉換爲(n,n,3)矩陣。我從一個(n,n,3)開始,但是當我在第二個軸上對(n,n)進行求和後,我的尺寸減小了。實質上,我想保持數組的原始大小,並讓第二個軸重複3次。我需要這個的原因是我稍後會用另一個(n,n,3)陣列播放它,但它們需要相同的尺寸。改變不同尺寸的陣列一起播出

我目前的方法有效,但看起來並不高雅。

a0=np.random.random((n,n)) 
b=a.flatten().tolist() 
a=np.array(zip(b,b,b)) 
a.shape=n,n,3 

此設置具有所需的結果,但是笨重且難以遵循。是否有可能通過複製第二個索引直接從(n,n)到(n,n,3)?或者可能是一種不縮小數組的開始?

回答

1

可以首先在a創建一個新的軸(軸= 2),然後使用np.repeat沿着該新的軸:

np.repeat(a[:,:,None], 3, axis = 2) 

或者另一種方法中,壓平的數組,重複的元素,然後重塑:

np.repeat(a.ravel(), 3).reshape(n,n,3) 

結果比較:

import numpy as np 
n = 4 
a=np.random.random((n,n)) 
b=a.flatten().tolist() 
a1=np.array(zip(b,b,b)) 
a1.shape=n,n,3 
# a1 is the result from the original method 

(np.repeat(a[:,:,None], 3, axis = 2) == a1).all() 
# True 

(np.repeat(a.ravel(), 3).reshape(4,4,3) == a1).all() 
# True 

定時,使用內置numpy.repeat還示出了加速:

import numpy as np 
n = 4 
a=np.random.random((n,n)) 
​ 
def rep(): 
    b=a.flatten().tolist() 
    a1=np.array(zip(b,b,b)) 
    a1.shape=n,n,3 

%timeit rep() 
# 100000 loops, best of 3: 7.11 µs per loop 

%timeit np.repeat(a[:,:,None], 3, axis = 2) 
# 1000000 loops, best of 3: 1.64 µs per loop 

%timeit np.repeat(a.ravel(), 3).reshape(4,4,3) 
# 1000000 loops, best of 3: 1.9 µs per loop 
2

Nonenp.newaxis是增加一個維度的陣列的常見方式。 reshape與(3,3,1)的作品一樣好:

In [64]: arr=np.arange(9).reshape(3,3) 
In [65]: arr1 = arr[...,None] 
In [66]: arr1.shape 
Out[66]: (3, 3, 1) 

repeat爲函數或方法複製本。

In [72]: arr2=arr1.repeat(3,axis=2) 
In [73]: arr2.shape 
Out[73]: (3, 3, 3) 
In [74]: arr2[0,0,:] 
Out[74]: array([0, 0, 0]) 

但你可能不需要這樣做。廣播(3,3,1)與(3,3,3)一起工作。

In [75]: (arr1+arr2).shape 
Out[75]: (3, 3, 3) 

事實上,它會播放一個(3,)產生(3,3,3)。

In [77]: arr1+np.ones(3,int) 
Out[77]: 
array([[[1, 1, 1], 
     [2, 2, 2], 
     ... 
     [[7, 7, 7], 
     [8, 8, 8], 
     [9, 9, 9]]]) 

因此arr1+np.zeros(3,int)是將(3,3,1)擴展到(3,3,3)的另一種方式。

廣播規則是:

(3,3,1) + (3,) => (3,3,1) + (1,1,3) => (3,3,3) 

廣播需要在開始增加尺寸。

當您在軸上綜上​​所述,你可以保持尺寸的原號碼與參數:

In [78]: arr2.sum(axis=2).shape 
Out[78]: (3, 3) 
In [79]: arr2.sum(axis=2, keepdims=True).shape 
Out[79]: (3, 3, 1) 

,如果你想從沿任意維度的數組中減去平均這是很方便

arr2-arr2.mean(axis=2, keepdims=True)