1
我有一個2維NumPy ndarray。如何在ndarray尋找所有argmax
array([[ 0., 20., -2.],
[ 2., 1., 0.],
[ 4., 3., 20.]])
如何獲取最大元素的所有索引?所以我想作爲輸出數組([0,1],[2,2])。
我有一個2維NumPy ndarray。如何在ndarray尋找所有argmax
array([[ 0., 20., -2.],
[ 2., 1., 0.],
[ 4., 3., 20.]])
如何獲取最大元素的所有索引?所以我想作爲輸出數組([0,1],[2,2])。
使用np.argwhere
上最大平等掩蓋 -
np.argwhere(a == a.max())
採樣運行 -
In [552]: a # Input array
Out[552]:
array([[ 0., 20., -2.],
[ 2., 1., 0.],
[ 4., 3., 20.]])
In [553]: a == a.max() # Max equality mask
Out[553]:
array([[False, True, False],
[False, False, False],
[False, False, True]], dtype=bool)
In [554]: np.argwhere(a == a.max()) # array of row, col indices of max-mask
Out[554]:
array([[0, 1],
[2, 2]])
如果您正在使用浮點數工作,你可能需要使用一些寬容那裏。所以,考慮到這一點,你可以使用np.isclose
,它有一些默認的絕對和相對容差值。這將取代早期的a == a.max()
部分,像這樣 -
In [555]: np.isclose(a, a.max())
Out[555]:
array([[False, True, False],
[False, False, False],
[False, False, True]], dtype=bool)