2016-11-18 28 views
4

考慮陣列a發現最大的分之前來到

a = np.array([ 
     [5, 4], 
     [4, 5], 
     [2, 2], 
     [6, 1], 
     [3, 7] 
    ]) 

我能找到其中的最小值與

a.argmin(0) 

array([2, 3]) 

如何找到0列的最大值在索引2之前。第1列和索引3也是如此。更重要的是,它們在哪裏?

如果我做

a.max(0) 

array([6, 7]) 

,但我需要

# max values 
array([5, 5]) 

# argmax before mins 
array([0, 1]) 

回答

4

下面是使用一種方法broadcasting -

b = np.where(a.argmin(0) >= np.arange(a.shape[0])[:,None],a,np.nan) 
idx = np.nanargmax(b,axis=0) 
out = a[idx,np.arange(a.shape[1])] 

採樣運行 -

In [38]: a 
Out[38]: 
array([[5, 4], 
     [4, 5], 
     [2, 2], 
     [6, 1], 
     [3, 7]]) 

In [39]: b = np.where(a.argmin(0) >= np.arange(a.shape[0])[:,None],a,np.nan) 
    ...: idx = np.nanargmax(b,axis=0) 
    ...: out = a[idx,np.arange(a.shape[1])] 
    ...: 

In [40]: idx 
Out[40]: array([0, 1]) 

In [41]: out 
Out[41]: array([5, 5]) 

或者,如果a只有正數,我們可以得到簡單地idx -

mask = a.argmin(0) >= np.arange(a.shape[0])[:,None] 
idx = (a*mask).argmax(0) 
+0

不錯!將其餘設置爲'np.nan'非常巧妙。有利於將來參考的問題!謝謝@Divakar。 – Kartik

+1

如果您好奇,我發佈了另一個答案。 – piRSquared

2

我知道我可以用的累積argmax @ajcr answered that question for me here

向量化版本回答
def ajcr(a): 
    m = np.maximum.accumulate(a) 
    x = np.repeat(np.arange(a.shape[0])[:, None], a.shape[1], axis=1) 
    x[1:] *= m[:-1] < m[1:] 
    np.maximum.accumulate(x, axis=0, out=x) 
    # at this point x houses the cumulative argmax 
    # we slice that with a's argmin 
    return x[a.argmin(0), np.arange(a.shape[1])] 

def divakar(a): 
    b = np.where(a.argmin(0) >= np.arange(a.shape[0])[:,None],a,np.nan) 
    return np.nanargmax(b,axis=0) 

比較

a = np.random.randn(10000, 1000) 
(ajcr(a) == divakar(a)).all() 

True 

定時

import timeit 

results = pd.DataFrame(
    [], [10, 100, 1000, 10000], 
    pd.MultiIndex.from_product(
     [['divakar', 'ajcr'], [10, 100, 1000]])) 

for i, j in results.stack(dropna=False).index: 
    a = np.random.randn(i, j) 
    results.loc[i, ('divakar', j)] = \ 
     timeit.timeit(
      'divakar(a)', 
      setup='from __main__ import divakar, a', 
      number=10) 
    results.loc[i, ('ajcr', j)] = \ 
     timeit.timeit(
      'ajcr(a)', 
      setup='from __main__ import ajcr, a', 
      number=10) 

import matplotlib.pyplot as plt 

fig, axes = plt.subplots(2, 2, figsize=(10, 5)) 
for i, (name, group) in enumerate(results.stack().groupby(level=0)): 
    r, c = i // 2, i % 2 
    group.xs(name).plot.barh(ax=axes[r, c], title=name) 
fig.tight_layout() 

enter image description here

results 

enter image description here

+0

啊我不知道這兩個問題是連接的。另一個問題似乎很難解決,但你有一個可愛的解決方案。很高興看到其他方式來實現這一目標。 – Divakar