2017-02-19

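Conditional nd argmin: How can I find the coordinates of the minimum of a subset of a multidimensional array?

I know I can use argmin and unravel_index to find the index of the smallest value in an ndarray, but what if I want to find the smallest nonzero element, or the smallest element that is not NaN?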
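To make the problem concrete: a plain argmin does not skip the unwanted values. NumPy's argmin treats the first NaN as the minimum, and a zero is a perfectly valid minimum, so both end up selected. A small illustration with made-up arrays:

```python
import numpy as np

a = np.array([[4.0, 0.0, np.nan],
              [2.0, 7.0, 1.0]])

# Plain argmin over the flattened array, mapped back to n-d coordinates.
# NumPy returns the position of the first NaN here, i.e. (0, 2).
print(np.unravel_index(np.argmin(a), a.shape))

b = np.array([[4.0, 0.0],
              [2.0, 1.0]])

# Without NaNs, the zero at (0, 1) wins, which is not what we want either.
print(np.unravel_index(np.argmin(b), b.shape))
```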

For the NaN case there is simply 'numpy.nanargmin'. – user7138814

@user7138814 Nice suggestion there. I added it to the timing tests in my post, and it looks very efficient! – Divakar

@Demetri P Did the posted solutions work for you? – Divakar
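The nanargmin suggestion from the comments can be sketched as follows (the array is made up for illustration). Note that it only covers the NaN case, not the nonzero case, and that NumPy raises a ValueError if every element is NaN.

```python
import numpy as np

a = np.array([[4.0, np.nan, 3.0],
              [np.nan, 1.5, 9.0]])

# nanargmin ignores NaNs and returns a flat index; unravel_index maps it
# back to coordinates, here the 1.5 at row 1, column 1.
idx = np.unravel_index(np.nanargmin(a), a.shape)
print(idx)
```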

Answers

Here's one approach using flattened indices -

import numpy as np

def flatnonzero_based(a, condition): # condition = a != 0 or ~np.isnan(a)
    idx = np.flatnonzero(condition) # flat positions where condition holds
    return np.unravel_index(idx[np.take(a, idx).argmin()], a.shape)
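A usage sketch on a small made-up 3D array. To exclude zeros and NaNs at the same time, the two masks can be combined with &; that combination is an assumption about what the asker wants, not part of the original answer.

```python
import numpy as np

def flatnonzero_based(a, condition):
    # Flat (C-order) positions where the condition holds.
    idx = np.flatnonzero(condition)
    # argmin over just those values picks one of idx; unravel it to n-d.
    return np.unravel_index(idx[np.take(a, idx).argmin()], a.shape)

a = np.array([[[0.0, 5.0], [3.0, 0.0]],
              [[np.nan, 2.5], [0.0, 8.0]]])

# Smallest element that is neither zero nor NaN: the 2.5 at (1, 0, 1).
cond = (a != 0) & ~np.isnan(a)
print(flatnonzero_based(a, cond))
```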

Benchmarking

Approaches -

import numpy as np

def flatnonzero_based(a, condition): # Proposed solution
    idx = np.flatnonzero(condition)
    return np.unravel_index(idx[np.take(a, idx).argmin()], a.shape)

def where_based(a, condition): # @Paul Panzer's solution
    nz = np.where(condition)
    return np.array(nz)[:, np.argmin(a[nz])]
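The two approaches return different types: where_based gives a 1-D array of coordinates, flatnonzero_based a tuple. A quick randomized cross-check of the kind done in the timings, sketched with an arbitrary seed and a roughly 20% NaN fraction (both assumptions), could look like this:

```python
import numpy as np

def flatnonzero_based(a, condition):
    idx = np.flatnonzero(condition)
    return np.unravel_index(idx[np.take(a, idx).argmin()], a.shape)

def where_based(a, condition):
    nz = np.where(condition)
    return np.array(nz)[:, np.argmin(a[nz])]

rng = np.random.default_rng(0)
for shape in [(4, 5), (3, 4, 5), (2, 3, 4, 5)]:
    a = rng.random(shape)
    a[rng.random(shape) < 0.2] = np.nan   # sprinkle NaNs at random positions
    cond = ~np.isnan(a)
    # Convert both results to tuples before comparing.
    assert tuple(where_based(a, cond)) == tuple(flatnonzero_based(a, cond))
print("both approaches agree")
```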

Timings and verification -

In [233]: a = np.random.rand(40,50,30) 

In [234]: nan_idx = np.random.choice(range(a.size), size = a.size//100, replace=0) 

In [235]: a.ravel()[nan_idx] = np.nan 

In [236]: condition = ~np.isnan(a) 

In [237]: where_based(a, condition) 
Out[237]: array([16, 10, 8]) 

In [238]: flatnonzero_based(a, condition) 
Out[238]: (16, 10, 8) 

In [239]: %timeit where_based(a, condition) 
1000 loops, best of 3: 877 µs per loop 

In [240]: %timeit flatnonzero_based(a, condition) 
10000 loops, best of 3: 143 µs per loop 

With 4D data -

In [255]: a = np.random.rand(40,50,30,30) 

In [256]: nan_idx = np.random.choice(range(a.size), size = a.size//100, replace=0) 

In [257]: a.ravel()[nan_idx] = np.nan 

In [258]: condition = ~np.isnan(a) 

In [259]: where_based(a, condition) 
Out[259]: array([34, 14, 5, 10]) 

In [260]: flatnonzero_based(a, condition) 
Out[260]: (34, 14, 5, 10) 

In [261]: %timeit where_based(a, condition) 
10 loops, best of 3: 64.9 ms per loop 

In [262]: %timeit flatnonzero_based(a, condition) 
100 loops, best of 3: 5.32 ms per loop 

Incorporating @user7138814's suggestion -

In [267]: np.unravel_index(np.nanargmin(a), a.shape) 
Out[267]: (34, 14, 5, 10) 

In [268]: %timeit np.unravel_index(np.nanargmin(a), a.shape) 
100 loops, best of 3: 4.54 ms per loop 
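One caveat shared by all three approaches: if no element satisfies the condition, they raise a ValueError (argmin of an empty sequence, or nanargmin on an all-NaN array). A guarded variant might look like the sketch below; the name safe_flatnonzero_based and the choice to return None are hypothetical, not from the answers above.

```python
import numpy as np

def safe_flatnonzero_based(a, condition):
    """Hypothetical helper: n-d index of the smallest a[condition],
    or None when no element satisfies the condition."""
    idx = np.flatnonzero(condition)
    if idx.size == 0:
        return None  # nothing qualifies; avoid argmin on an empty array
    return np.unravel_index(idx[np.take(a, idx).argmin()], a.shape)

all_nan = np.full((2, 3), np.nan)
print(safe_flatnonzero_based(all_nan, ~np.isnan(all_nan)))  # None

b = np.array([[0.0, 2.0], [1.0, 0.0]])
print(safe_flatnonzero_based(b, b != 0))  # the 1.0 at (1, 0)
```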

This should work (where condition is data != 0 or ~np.isnan(data)):

import numpy as np

nz = np.where(condition) # one index array per axis
cond_arg_min = np.array(nz)[:, np.argmin(data[nz])] # coordinates of the smallest data[condition]
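A self-contained run of the snippet above on a made-up 2D array, using the nonzero condition:

```python
import numpy as np

data = np.array([[0.0, 4.0, 1.0],
                 [3.0, 0.0, 2.0]])

condition = data != 0          # or ~np.isnan(data)
nz = np.where(condition)       # tuple of index arrays, one per axis
# np.array(nz) has one row per axis; pick the column of the smallest value.
cond_arg_min = np.array(nz)[:, np.argmin(data[nz])]
print(cond_arg_min)            # [0 2]
```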