2017-08-01 53 views
2

我有一個全球numpy.array 數據這是一個200 * 200 * 3的三維陣列在三維空間中包含40000點。 (0,0,0),(1,0,0),(0,1,0),(0,0,0,0),(1,0,0,0,0), 0,1)),所以我可以確定哪個角落離它最近。Python如何提高numpy數組的性能?

def dist(*point): 
    return np.linalg.norm(data - np.array(rgb), axis=2) 

buffer = np.stack([dist(0, 0, 0), dist(1, 0, 0), dist(0, 1, 0), dist(0, 0, 1)]).argmin(axis=0) 

我寫了這段代碼並進行了測試,每次運行約耗時10ms。 我的問題是如何提高這段代碼的性能,更好地在不到1ms的時間內運行。

+0

不立方體有超過4個角? –

+1

@JohnZwinck只需要計算其中四個的距離。 – iouvxz

回答

3

你可以使用Scipy cdist -

# unit cube coordinates as array 
uc = np.array([[0, 0, 0],[1, 0, 0], [0, 1, 0], [0, 0, 1]]) 

# buffer output 
buf = cdist(data.reshape(-1,3), uc).argmin(1).reshape(data.shape[0],-1) 

運行測試

# Original approach 
def org_app(): 
    return np.stack([dist(0, 0, 0), dist(1, 0, 0), \ 
     dist(0, 1, 0), dist(0, 0, 1)]).argmin(axis=0) 

計時 -

In [170]: data = np.random.rand(200,200,3) 

In [171]: %timeit org_app() 
100 loops, best of 3: 4.24 ms per loop 

In [172]: %timeit cdist(data.reshape(-1,3), uc).argmin(1).reshape(data.shape[0],-1) 
1000 loops, best of 3: 1.25 ms per loop