2012-08-29 47 views
1

我試圖邊緣化一個軸上的數組,並檢查一維峯值是否出現在與原始二維峯值相同的相關索引處。在什麼情況下(x的形式)下面的斷言是否失敗?numpy array - argmax v argmax的總和的unravel_index

def check(x,axis=None): 
    import numpy 
    m = numpy.sum(x, axis=axis) 
    v,w = numpy.unravel_index(numpy.argmax(x), x.shape) 
    assert(v==numpy.argmax(m)) 
    return 

對於x=numpy.array(range(15)).reshape(5,3)check(x,axis=0)引發錯誤,但check(x,axis=1)沒有。我看不出爲什麼有人提出AssertionError - 我是不是很愚蠢?

+0

您是否嘗試打印'v'和numpy.argmax(m)的值?也許實際值會給你一個線索 – Dhara

+0

當axis = None時斷言失敗嗎?這可能是因爲sum函數將數組中的所有值相加,給出m [不是數組]的單個值。這個argmax總是0 – Dhara

回答

1

您正在檢查拆散索引的錯誤座標。取而代之的

v,w=numpy.unravel_index(numpy.argmax(x),x.shape) 
assert(v==numpy.argmax(m)) 

你可能想

vw = numpy.unravel_index(numpy.argmax(x),x.shape) 
assert vw[1 - axis] == numpy.argmax(m) 

或許

v,w=numpy.unravel_index(numpy.argmax(x),x.shape) 
assert (v if axis == 1 else w) == numpy.argmax(m) 
1

axis參數的值是非常關鍵的。

隨着x一個(N, M)陣列,m=np.sum(x, axis=axis)會給你

  • 標如果axis=None;
  • a M array if axis=0;
  • a N array if axis=1

因此,你的np.argmax(m)將總是0是如果axis=None,或者如果axis=0(相應的axis=10M(相應的N)之間的整數。

但是,您的(v, w) = np.unravel_index(...)將始終會給您v作爲0到N之間的整數。

正如你所看到的,與axis=0,潛力值的m的範圍是不一樣的,作爲v,而它是axis=1

所以,比較mv如果axis=1,或w如果axis=0(@ ecatmur的答案告訴您如何)。