2015-10-02 44 views
5

我有一個不規則的(非矩形)lon/lat網格和一些點在lon/lat座標,它應該對應網格上的點(儘管它們由於數字原因可能會稍微偏離)。現在我需要相應的長/點的指數。高效找到非矩形2D網格上的最近點索引

我寫過一個這樣做的函數,但它真的很慢。

def find_indices(lon,lat,x,y): 
    lonlat = np.dstack([lon,lat]) 
    delta = np.abs(lonlat-[x,y]) 
    ij_1d = np.linalg.norm(delta,axis=2).argmin() 
    i,j = np.unravel_index(ij_1d,lon.shape) 
    return i,j 

ind = [find_indices(lon,lat,p*) for p in points] 

我很肯定在numpy/scipy中有一個更好(更快)的解決方案。我已經使用了很多搜索引擎,但是迄今爲止我的答案已經不存在了。

任何建議如何有效地找到相應(最近)點的指數?

PS:這個問題出現了另一個問題(click)。

編輯:解

基於@Cong馬雲的回答,我已經找到了以下解決方案:

def find_indices(points,lon,lat,tree=None): 
    if tree is None: 
     lon,lat = lon.T,lat.T 
     lonlat = np.column_stack((lon.ravel(),lat.ravel())) 
     tree = sp.spatial.cKDTree(lonlat) 
    dist,idx = tree.query(points,k=1) 
    ind = np.column_stack(np.unravel_index(idx,lon.shape)) 
    return [(i,j) for i,j in ind] 

爲了把這個解決方案,並從Divakar的回答也是一個進入的角度看,這裏有該功能的一些時機在我使用find_indices(以及它是在速度方面的瓶頸)(見上面的鏈接):

spatial_contour_frequency/pil0    : 331.9553 
spatial_contour_frequency/pil1    : 104.5771 
spatial_contour_frequency/pil2    :  2.3629 
spatial_contour_frequency/pil3    :  0.3287 

pil0是我最初的方法,pil1 Divakar's和pil2/pil3上面的最終解決方案,其中樹是在pil2(即pil2)中即時創建的。對於調用find_indices的循環的每次迭代)和pil3(參見其他線程的詳細信息)中的一次。儘管Divakar對我最初的方法進行了改進,使我的速度提高了3倍,但cKDTree以另一個50倍提速將這個提升到了一個全新的水平!將樹的創建移出該功能使事情變得更快。

+0

在您的腳本中,每次調用'find_indices'時都會創建一棵新樹。如果你的網格在呼叫中被修復,那麼你就是在浪費時間重複構建相同的樹。實際上,這個樹的構造是這個函數中最慢的調用。 –

+0

是的,我注意到,這就是我現在正在做的事情。 ;)我會相應地更新答案。謝謝你的評論。 – flotzilla

回答

4

如果這些點已經足夠本地化,那麼您可以直接嘗試scipy.spatialcKDTree實現,如我自己in another post所討論的。這篇文章是關於插值的,但你可以忽略它,只使用查詢部分。

TL;博士版本:

閱讀起來scipy.sptial.cKDTree的文檔。通過將(n, m)-形ndarray對象傳遞給初始值設定項來創建樹,並且將從維數座標創建樹。

tree = scipy.spatial.cKDTree(array_of_coordinates) 

之後,使用tree.query()檢索k個最近的鄰居(可能逼近和並行化,見文檔),或使用tree.query_ball_point()找到給定距離的公差範圍內所有鄰居。

如果這些點不是很好的局部化,並且球面曲率/非平凡拓撲起作用,您可以嘗試將多面體分成多個部分,每個部分都足夠小以被認爲是局部的。

1

下面是使用scipy.spatial.distance.cdist一個通用的量化方法 -

import scipy 

# Stack lon and lat arrays as columns to form a Nx2 array, where is N is grid**2 
lonlat = np.column_stack((lon.ravel(),lat.ravel())) 

# Get the distances and get the argmin across the entire N length 
idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0) 

# Get the indices corresponding to grid's shape as the final output 
ind = np.column_stack((np.unravel_index(idx,lon.shape))).tolist() 

採樣運行 -

In [161]: lon 
Out[161]: 
array([[-11. , -7.82 , -4.52 , -1.18 , 2.19 ], 
     [-12. , -8.65 , -5.21 , -1.71 , 1.81 ], 
     [-13. , -9.53 , -5.94 , -2.29 , 1.41 ], 
     [-14.1 , -0.04 , -6.74 , -2.91 , 0.976]]) 

In [162]: lat 
Out[162]: 
array([[-11.2 , -7.82 , -4.51 , -1.18 , 2.19 ], 
     [-12. , -8.63 , -5.27 , -1.71 , 1.81 ], 
     [-13.2 , -9.52 , -5.96 , -2.29 , 1.41 ], 
     [-14.3 , -0.06 , -6.75 , -2.91 , 0.973]]) 

In [163]: lonlat = np.column_stack((lon.ravel(),lat.ravel())) 

In [164]: idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0) 

In [165]: np.column_stack((np.unravel_index(idx,lon.shape))).tolist() 
Out[165]: [[0, 4], [0, 4], [0, 4], [0, 4], [0, 4], [0, 4], [3, 3]] 

運行測試 -

定義功能:

def find_indices(lon,lat,x,y): 
    lonlat = np.dstack([lon,lat]) 
    delta = np.abs(lonlat-[x,y]) 
    ij_1d = np.linalg.norm(delta,axis=2).argmin() 
    i,j = np.unravel_index(ij_1d,lon.shape) 
    return i,j 

def loopy_app(lon,lat,pts): 
    return [find_indices(lon,lat,pts[i,0],pts[i,1]) for i in range(pts.shape[0])] 

def vectorized_app(lon,lat,points): 
    lonlat = np.column_stack((lon.ravel(),lat.ravel())) 
    idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0) 
    return np.column_stack((np.unravel_index(idx,lon.shape))).tolist() 

個時序:

In [179]: lon = np.random.rand(100,100) 

In [180]: lat = np.random.rand(100,100) 

In [181]: points = np.random.rand(50,2) 

In [182]: %timeit loopy_app(lon,lat,points) 
10 loops, best of 3: 47 ms per loop 

In [183]: %timeit vectorized_app(lon,lat,points) 
10 loops, best of 3: 16.6 ms per loop 

對於擠掉更多的性能,np.concatenate可以代替np.column_stack使用。