2017-07-25 266 views
5

問題

我有兩個numpy的陣列,Aindicesnumpy的匹配索引尺寸

A具有尺寸m x n x 10000. indices具有尺寸m x n x 5(從argpartition(A, 5)[:,:,:5]輸出)。 我想得到一個m×n×5的數組,其中包含對應於indicesA的元素。

嘗試

indices = np.array([[[5,4,3,2,1],[1,1,1,1,1],[1,1,1,1,1]], 
    [500,400,300,200,100],[100,100,100,100,100],[100,100,100,100,100]]) 
A = np.reshape(range(2 * 3 * 10000), (2,3,10000)) 

A[...,indices] # gives an array of size (2,3,2,3,5). I want a subset of these values 
np.take(A, indices) # shape is right, but it flattens the array first 
np.choose(indices, A) # fails because of shape mismatch. 

動機

我試圖得到A[i,j] 5個最大值爲每i<mj<n使用np.argpartition因爲陣列可以得到相當大的排序順序。

回答

5

您可以使用advanced-indexing -

m,n = A.shape[:2] 
out = A[np.arange(m)[:,None,None],np.arange(n)[:,None],indices] 

採樣運行 -

In [330]: A 
Out[330]: 
array([[[38, 21, 61, 74, 35, 29, 44, 46, 43, 38], 
     [22, 44, 89, 48, 97, 75, 50, 16, 28, 78], 
     [72, 90, 48, 88, 64, 30, 62, 89, 46, 20]], 

     [[81, 57, 18, 71, 43, 40, 57, 14, 89, 15], 
     [93, 47, 17, 24, 22, 87, 34, 29, 66, 20], 
     [95, 27, 76, 85, 52, 89, 69, 92, 14, 13]]]) 

In [331]: indices 
Out[331]: 
array([[[7, 8, 1], 
     [7, 4, 7], 
     [4, 8, 4]], 

     [[0, 7, 4], 
     [5, 3, 1], 
     [1, 4, 0]]]) 

In [332]: m,n = A.shape[:2] 

In [333]: A[np.arange(m)[:,None,None],np.arange(n)[:,None],indices] 
Out[333]: 
array([[[46, 43, 21], 
     [16, 97, 16], 
     [64, 46, 64]], 

     [[81, 14, 43], 
     [87, 24, 47], 
     [27, 52, 95]]]) 

爲了得到相對應的最大沿最後軸5種元素的索引,我們將使用argpartition,像這樣 -

indices = np.argpartition(-A,5,axis=-1)[...,:5] 

爲了保持訂單從最高到最低,我們e range(5)而不是5

1

爲子孫後代,下面採用Divakar的答案來完成原來的目標,即在排序的順序返回的前5名值的所有i<m, j<n

m, n = np.shape(A)[:2] 

# get the largest 5 indices for all m, n 
top_unsorted_indices = np.argpartition(A, -5, axis=2)[...,-5:] 

# get the values corresponding to top_unsorted_indices 
top_values = A[np.arange(m)[:,None,None], np.arange(n)[:,None], top_unsorted_indices] 

# sort the top 5 values 
top_sorted_indices = top_unsorted_indices[np.arange(m)[:,None,None], np.arange(n)[:,None], np.argsort(-top_values)]