列表理解是非常低效處理numpy數組的方法。對於距離計算來說,它們是一個特別糟糕的選擇。
要找到你的數據和一個點之間的區別,你只需要做data - point
。然後,您可以使用np.hypot
來計算距離,或者如果您願意,可以將其平方,將其相加,並取平方根。
如果您爲了計算目的而將其設置爲Nx2數組,那麼它會更容易一些。
基本上,你想是這樣的:
import numpy as np
data = np.array([[[1704, 1240],
[1745, 1244],
[1972, 1290],
[2129, 1395],
[1989, 1332]],
[[1712, 1246],
[1750, 1246],
[1964, 1286],
[2138, 1399],
[1989, 1333]],
[[1721, 1249],
[1756, 1249],
[1955, 1283],
[2145, 1399],
[1990, 1333]]])
point = [1989, 1332]
#-- Calculate distance ------------
# The reshape is to make it a single, Nx2 array to make calling `hypot` easier
dist = data.reshape((-1,2)) - point
dist = np.hypot(*dist.T)
# We can then reshape it back to AxBx1 array, similar to the original shape
dist = dist.reshape(data.shape[0], data.shape[1], 1)
print dist
這產生了:
array([[[ 299.48121811],
[ 259.38388539],
[ 45.31004304],
[ 153.5219854 ],
[ 0. ]],
[[ 290.04310025],
[ 254.0019685 ],
[ 52.35456045],
[ 163.37074401],
[ 1. ]],
[[ 280.55837182],
[ 247.34186868],
[ 59.6405902 ],
[ 169.77926846],
[ 1.41421356]]])
現在,除去最近的元素不是簡單地讓最接近的元素更難一點。
隨着numpy,你可以使用布爾索引相當容易地做到這一點。
但是,您需要擔心一些關於軸的對齊方式。
關鍵是要了解沿着軸的numpy「廣播」操作的最後的軸。在這種情況下,我們想沿着中軸進行播客。
此外,-1
可以用作軸的大小的佔位符。當-1
作爲軸的大小放置時,Numpy將計算允許的大小。
什麼我們需要做的看起來有點像這樣:
#-- Remove closest point ---------------------
mask = np.squeeze(dist) != dist.min(axis=1)
filtered = data[mask]
# Once again, let's reshape things back to the original shape...
filtered = filtered.reshape(data.shape[0], -1, data.shape[2])
你可以做一個單一的線,我只是將它分解爲可讀性。關鍵是dist != something
會生成一個布爾數組,然後您可以使用它來索引原始數組。
所以,全部放在一起:
import numpy as np
data = np.array([[[1704, 1240],
[1745, 1244],
[1972, 1290],
[2129, 1395],
[1989, 1332]],
[[1712, 1246],
[1750, 1246],
[1964, 1286],
[2138, 1399],
[1989, 1333]],
[[1721, 1249],
[1756, 1249],
[1955, 1283],
[2145, 1399],
[1990, 1333]]])
point = [1989, 1332]
#-- Calculate distance ------------
# The reshape is to make it a single, Nx2 array to make calling `hypot` easier
dist = data.reshape((-1,2)) - point
dist = np.hypot(*dist.T)
# We can then reshape it back to AxBx1 array, similar to the original shape
dist = dist.reshape(data.shape[0], data.shape[1], 1)
#-- Remove closest point ---------------------
mask = np.squeeze(dist) != dist.min(axis=1)
filtered = data[mask]
# Once again, let's reshape things back to the original shape...
filtered = filtered.reshape(data.shape[0], -1, data.shape[2])
print filtered
產量:
array([[[1704, 1240],
[1745, 1244],
[1972, 1290],
[2129, 1395]],
[[1712, 1246],
[1750, 1246],
[1964, 1286],
[2138, 1399]],
[[1721, 1249],
[1756, 1249],
[1955, 1283],
[2145, 1399]]])
在一個側面說明,如果一個以上的點是同樣接近,這是不行的。 Numpy數組必須沿着每個維度具有相同數量的元素,因此在這種情況下您需要重新進行分組。
啊,不知何故,我沒有看到這之前,我張貼。我想過使用'apply_along_axis',但我測試了它,速度更快。 – senderle
'apply_along_axis'應該使用更少的內存,所以這兩種方法仍然有用! –
謝謝!非常簡潔,但內容豐富。太快了。 – OneTrickyPony