2016-11-03 39 views
-1

我想找到一個numpy的矩陣最大的行和列的索引。但它不在一組行或列中。因此,它應該在計算最大值時跳過這些行和列。查找Max在numpy的跳過一些行和列

例子:

# finding max in numpy matrix 
[row,col] = np.where(mat == mat.max()) 

但它應該跳過行removed_rows = []和列columns_rows = []

我不想創建計算一個新的子矩陣。

+0

數組/矩陣是否有負數? – Divakar

+0

@Divakar不,它只包含非負數。 –

+0

可以更改輸入數組嗎? – Divakar

回答

3

假設a是輸入數組,rows_remcols_rem分別是要跳過的行和列索引。我們將使用掩蔽有一個辦法,像這樣 -

m,n = a.shape 
d0,d1 = np.ogrid[:m,:n] 
a_masked = a*~(np.in1d(d0,rows_rem)[:,None] | np.in1d(d1,cols_rem)) 
max_row, max_col = np.where(a_masked == a_masked.max()) 

採樣運行 -

In [204]: # Inputs 
    ...: a = np.random.randint(11,99,(4,5)) 
    ...: rows_rem = [1,3] 
    ...: cols_rem = [1,2,4] 
    ...: 

In [205]: a 
Out[205]: 
array([[36, 51, 72, 18, 31], 
     [78, 42, 12, 71, 72], 
     [38, 46, 42, 67, 12], 
     [87, 56, 76, 14, 21]]) 

In [206]: a_masked 
Out[206]: 
array([[64, 0, 0, 90, 0], 
     [ 0, 0, 0, 0, 0], 
     [17, 0, 0, 40, 0], 
     [ 0, 0, 0, 0, 0]]) 

In [207]: max_row, max_col 
Out[207]: (array([0]), array([3])) 

請注意,如果使用相同的最大值超過一個元素是,我們將有所有那些在輸出。所以,如果你想要的任何或第一的,我們可以用argmax,像這樣 -

max_row, max_col = np.unravel_index(a_masked.argmax(),a.shape) 
+0

任何通過切片的方式,跳過這些行和列。 –

+0

請檢查問題中的編輯。它給出了不在矩陣中的元素。 –

0
remove_rows = [2,3] 
remove_cols = [0,1] 

a = np.random.randint(11,99,(4,5)) 

>>> a 
array([[60, 86, 89, 66, 20], 
     [77, 86, 78, 90, 44], 
     [68, 57, 83, 48, 25], 
     [30, 81, 42, 11, 63]]) 
>>> 

把那您有興趣通過過濾出你想要刪除索引的行和列索引:

r, c = a.shape 
r = [x for x in range(r) if x not in remove_rows] 
c = [x for x in range(c) if x not in remove_cols] 

>>> r,c 
([0, 1], [2, 3, 4]) 
>>> 

現在rc可用於integer indexingnumpy.ix_幫助與此有關。

>>> a[np.ix_(r,c)] 
array([[89, 66, 20], 
     [78, 90, 44]]) 
>>> 

釘在ndarray.max()來獲取最大價值:

>>> a[np.ix_(r,c)].max() 
90 
>>> 

最後,找到使用numpy.where它是原始數組中:

>>> row, col = np.where(a == a[np.ix_(r,c)].max()) 
>>> row, col 
(array([1]), array([3])) 
>>> 

這種方法也將工作,如果除去非連續行或列。 例如:

remove_rows = [0,3] 
remove_cols = [1,4] 
+0

您的編輯答案也會發生同樣的情況。它給出了不在矩陣中的元素。請用'a = np.array([[11.,1。],[1.,5。],[1.,11。]])''檢查。 –

+0

@AbhishekBhatia,我不明白:你想要我檢查什麼/你怎麼檢查a = np.array([[11.,1。],[1.,5。],[1.,11 ]])'''? – wwii