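Differences in Numba outputs
I implemented a basic nearest-neighbor search as part of a study project. The plain NumPy implementation works well, but simply adding the @jit decorator (to compile it with Numba) changes the output: for some unknown reason it duplicates some of the neighbors.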
Here is the basic algorithm:
import numpy as np
from numba import jit

@jit(nopython=True)
def knn(p, points, k):
    '''Find the k nearest neighbors (brute force) of the point p
    in the list points (each row is a point)'''

    n = p.size  # Length of the points
    M = points.shape[0]  # Number of points
    neighbors = np.zeros((k, n))
    distances = 1e6 * np.ones(k)

    for i in range(M):
        d = 0
        pt = points[i, :]  # Point to compare
        for r in range(n):  # For each coordinate
            aux = p[r] - pt[r]
            d += aux * aux
        if d < distances[k-1]:  # We find a new neighbor
            pos = k-1
            while pos > 0 and d < distances[pos-1]:  # Find the position
                pos -= 1
            pt = points[i, :]
            # Insert neighbor and distance:
            neighbors[pos+1:, :] = neighbors[pos:-1, :]
            neighbors[pos, :] = pt
            distances[pos+1:] = distances[pos:-1]
            distances[pos] = d

    return neighbors, distances
To test it:
p = np.random.rand(10)
points = np.random.rand(250, 10)
k = 5
neighbors, distances = knn(p, points, k)
Without the @jit decorator, I get the correct answer:
In [1]: distances
Out[1]: array([ 0.3933974 , 0.44754336, 0.54548715, 0.55619749, 0.5657846 ])
But the Numba-compiled version gives strange output:
In [2]: distances
Out[2]: array([ 0.3933974 , 0.44754336, 0.54548715, 0.54548715, 0.54548715])
Can anyone help? I don't know why this happens...
Thank you.
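One possible explanation, offered here as an assumption that nothing in this thread confirms, is that the shifting step copies between overlapping slices of the same array (`distances[pos+1:] = distances[pos:-1]`). NumPy handles that overlap, but a naive front-to-back element copy would smear one value across the tail, which matches the repeated `0.54548715` shown above. The sketch below reproduces the pattern in pure Python and NumPy, without Numba; the helper names and the sample values are hypothetical.

```python
import numpy as np

def insert_numpy_slices(distances, pos, d):
    """Shift with overlapping slice assignment; NumPy handles the overlap."""
    out = distances.copy()
    out[pos + 1:] = out[pos:-1]
    out[pos] = d
    return out

def insert_forward_copy(distances, pos, d):
    """Naive front-to-back copy: each element reads a value that was just overwritten."""
    out = distances.copy()
    for j in range(pos + 1, out.size):
        out[j] = out[j - 1]
    out[pos] = d
    return out

def insert_backward_copy(distances, pos, d):
    """Back-to-front copy: safe for this overlap, needs no temporary array."""
    out = distances.copy()
    for j in range(out.size - 1, pos, -1):
        out[j] = out[j - 1]
    out[pos] = d
    return out

# Hypothetical sorted distances; insert 1.5 at position 1.
dist = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
ok = insert_numpy_slices(dist, 1, 1.5)      # expected shift: 1, 1.5, 2, 3, 4
smeared = insert_forward_copy(dist, 1, 1.5)  # tail collapses to repeated 2s
safe = insert_backward_copy(dist, 1, 1.5)    # same result as the NumPy version
```

If this is indeed the cause, replacing the two slice assignments in `knn` with explicit back-to-front loops (or copying the source slice first with `.copy()`) should make the compiled and uncompiled results agree; that would need to be checked against the Numba version in use.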
You may be interested in SciPy's [KDTree](http://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.cKDTree.html) implementation. – Daniel
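@Ophion Thanks for the tip. I have been playing with sklearn's KDTree implementation (I guess they are similar), and it is very useful for preprocessing data ahead of many future query points. In my work (image processing), however, the list of points changes every time I need to find neighbors, and that implementation becomes too slow. Also, when the dimension of the space is large (say, greater than 25), the KDTree implementation does not seem to do any better than brute force. –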
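For reference, the SciPy suggestion in the comments could look roughly like the sketch below. It uses `scipy.spatial.cKDTree` and its `query` method; the seeded random data is hypothetical and only mirrors the shapes used in the question. Note that `query` returns Euclidean distances, whereas `knn` in the question accumulates squared distances, so the values are squared here before comparing.

```python
import numpy as np
from scipy.spatial import cKDTree

# Hypothetical data with the same shapes as in the question.
rng = np.random.default_rng(0)
p = rng.random(10)
points = rng.random((250, 10))
k = 5

# Build the tree once, then query the k nearest neighbors of p.
tree = cKDTree(points)
dist, idx = tree.query(p, k=k)  # Euclidean distances, sorted ascending

kd_neighbors = points[idx]  # shape (k, 10), one neighbor per row
kd_sq_distances = dist ** 2  # comparable to the squared distances from knn

# Brute-force cross-check with plain NumPy.
brute_sq = np.sort(((points - p) ** 2).sum(axis=1))[:k]
```

Because the tree has to be rebuilt whenever `points` changes, this is mainly useful as a correctness check in the changing-points scenario described above, not necessarily as a faster replacement.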