2014-10-12 57 views
6

我有一個較大的2d numpy數組,我想提取每行的最低10個元素以及它們的索引。由於我的數組非常大,我寧願不對整個數組進行排序。如何將numpy.argpartition的輸出應用於二維數組?

我聽說argpartition()功能,與我能得到最低的10種元素的索引:

top10indexes = np.argpartition(myBigArray,10)[:,:10] 

注意argpartition()分區軸-1默認情況下,這是我想要的。這裏的結果與包含索引的myBigArray具有相同的形狀,因此前10個索引指向10個最低值。

我該如何提取對應於這些索引的myBigArray的元素?

明顯的花式索引像myBigArray[top10indexes]myBigArray[:,top10indexes]做了很大的不同。我還可以使用列表解析,是這樣的:

array([row[idxs] for row,idxs in zip(myBigArray,top10indexes)]) 

但這樣會導致性能命中迭代numpy的行並將結果轉換回一個數組。

nb:我可以使用np.partition()來獲取值,它們甚至可能對應於索引(或可能不是),但是如果我可以避免它,我不想做兩次分區。

回答

6

可以儘量避免使用扁平副本,需要做提取所有的值:

num = 10 
top = np.argpartition(myBigArray, num, axis=1)[:, :num] 
myBigArray[np.arange(myBigArray.shape[0])[:, None], top] 

對於與NumPy> = 1.9.0,這將是非常有效和媲美np.take()

+2

我用'flatten()'刪除了我的答案。我研究了它爲什麼不起作用,但看不到任何簡單的方法來修復它,而沒有有效地製作出更復雜的版本! – 2014-10-12 12:20:35

+1

gr8!我也瞭解到'None'和'newaxis'在這裏扮演着同樣的角色:) btw,在你的答案中'arr'應該是'myBigArray',以防我的編輯不被接受。 – drevicko 2014-10-12 22:26:06