2017-10-20 192 views
3

假設我有一個三維數組:Numpy:如何使用argmax結果來獲得實際的最大值?

>>> a 
array([[[7, 0], 
     [3, 6]], 

     [[2, 4], 
     [5, 1]]]) 

我可以得到它的argmax沿axis=1使用

>>> m = np.argmax(a, axis=1) 
>>> m 
array([[0, 1], 
     [1, 0]]) 

如何使用m的一個索引a,使結果等同於簡單地使用max

>>> a.max(axis=1) 
array([[7, 6], 
     [5, 4]]) 

(當m被施加到相同的形狀的其他陣列,這是有用)

回答

3

可以與advanced indexingnumpy broadcasting做到這一點:

m = np.argmax(a, axis=1) 
a[np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])] 

#array([[7, 6], 
#  [5, 4]]) 

m = np.argmax(a, axis=1) 

創建第一,第二和第三維度索引數組:

ind1, ind2, ind3 = np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2]) 
​ 

由於尺寸不匹配的,這三個陣列將廣播,這導致各爲如下:

for x in np.broadcast_arrays(ind1, ind2, ind3): 
    print(x, '\n') 

#[[0 0] 
# [1 1]] 

#[[0 1] 
# [1 0]] 

#[[0 1] 
# [0 1]] 

由於所有索引是整數陣列,它觸發advanced indexing,所以具有索引元件(0, 0, 0), (0, 1, 1), (1, 1, 0), (1, 0, 1)被拾取,即來自每個陣列的一個元素組合爲索引。

2

您可以使用np.ogrid爲您的陣列在所有軸上創建一個網格,但縮小的網格除外。然後,只需插入argmax結果在您的軸和index你,結果陣列的位置:

>>> import numpy as np 
>>> a = np.array([[[7, 0], [3, 6]], [[2, 4], [5, 1]]]) 
>>> axis = 1 

>>> # Create the grid 
>>> idx = list(np.ogrid[[slice(a.shape[ax]) for ax in range(a.ndim) if ax != axis]]) 
>>> argmaxes = np.argmax(a, axis=axis) 
>>> idx.insert(axis, argmaxes) 

>>> # Index the original array with the grid 
>>> a[idx] 
array([[7, 6], 
     [5, 4]]) 

請注意,這並不爲axis=None或情況下工作,你多軸減少了。

相關問題