2011-08-23 41 views
3

我有一個2D屏蔽的值數組,我需要從最低到最高排序。例如:NumPy:從閾值以上和以下的屏蔽二維數組中查找排序的索引

import numpy as np 

# Make a random masked array 
>>> ar = np.ma.array(np.round(np.random.normal(50, 10, 20), 1), 
        mask=np.random.binomial(1, .2, 20)).reshape((4,5)) 
>>> print(ar) 
[[-- 51.9 38.3 46.8 43.3] 
[52.3 65.0 51.2 46.5 --] 
[56.7 51.1 -- 38.6 33.5] 
[45.2 56.8 74.1 58.4 56.4]] 

# Sort the array from lowest to highest, with a flattened index 
>>> sorted_ind = ar.argsort(axis=None) 
>>> print(sorted_ind) 
[14 2 13 4 15 8 3 11 7 1 5 19 10 16 18 6 17 0 12 9] 

但隨着分類指數方面,我需要把它們分爲兩個簡單的子集:比小於或等於大於或等於一個給定的數據。此外,我不需要蒙版值,他們需要被刪除。例如,對於datum = 51.1,如何將sorted_ind過濾到datum以上的10個索引以及8個以下的值? (注意:由於或等於邏輯標準,有一個共享索引,可以從分析中刪除3個掩碼值)。我需要保留扁平的索引位置,因爲稍後我會使用np.unravel_index(ind, ar.shape)

+0

這也可以用numpy的'where'函數來優雅地實現 – wim

+0

@wim:我正在試着用'where'的技巧,但是我無法理解輸出。雖然不太「Numpythonic」,但「過濾器」技術看起來工作得很好。 –

回答

5

使用其中:

import numpy as np 
np.random.seed(0) 
# Make a random masked array 
ar = np.ma.array(np.round(np.random.normal(50, 10, 20), 1), 
        mask=np.random.binomial(1, .2, 20)).reshape((4,5)) 
# Sort the array from lowest to highest, with a flattened index 
sorted_ind = ar.argsort(axis=None) 

tmp = ar.flatten()[sorted_ind] 
print sorted_ind[np.ma.where(tmp<=51.0)] 
print sorted_ind[np.ma.where(tmp>=51.0)] 

但由於TMP的排序,你可以使用np.searchsorted() :

tmp = ar.flatten()[sorted_ind].compressed() # compressed() will delete all invalid data. 
idx = np.searchsorted(tmp, 51.0) 
print sorted_ind[:idx] 
print sorted_ind[idx:len(tmp)] 
+0

你的第一個例子是一個很好的例子,如何使用'where',並且比'filter'更快。第二個例子並不像預期的那樣工作,因爲在我的情況下,兩個子集可以重疊(由於*或等於*條件)。 –

3

的準備:

>>> ar = np.ma.array(np.round(np.random.normal(50, 10, 20), 1), 
        mask=np.random.binomial(1, .2, 20)).reshape((4,5)) 
>>> print(ar) 
[[59.9 51.3 -- 19.7 --] 
[59.1 57.2 48.6 49.8 46.3] 
[51.1 61.6 36.9 52.2 51.7] 
[37.9 -- -- 53.1 57.5]] 
>>> sorted_ind = ar.argsort(axis=None) 
>>> sorted_ind 
array([ 3, 12, 15, 9, 7, 8, 10, 1, 14, 13, 18, 6, 19, 5, 0, 11, 4, 
     2, 16, 17]) 

然後把新的東西

>>> flat = ar.flatten() 
>>> leq_ind = filter(lambda x: flat[x] <= 51.1, sorted_ind) 
>>> leq_ind 
[3, 12, 15, 9, 7, 8, 10] 
>>> geq_ind = filter(lambda x: flat[x] >= 51.1, sorted_ind) 
>>> geq_ind 
[10, 1, 14, 13, 18, 6, 19, 5, 0, 11] 
+0

太棒了!你只需要從'<=' to '> ='改變邏輯符號''geq_ind' –

+0

哈哈粘貼了錯誤的代碼,對不起。固定。 –

+0

將爲sorted_ind中的每個元素調用ar.flatten()。它會每次創建一個扁平數組。對於大型陣列,它非常緩慢。 – HYRY