2013-04-26 30 views
1

我想使用weave.blitz改善以下numpy的代碼的性能:閃電代碼產生不同的輸出

def fastIteration(self): 
    g = self.grid 
    nx,ny = g.ux.shape 

    uxold = g.old_ux 
    ux = g.ux 
    ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2]) 

    g.setBC() 
    g.old_ux = ux.copy() 

在此代碼g是計算網格。它由兩個不同的領域ux和uxold組成。舊的只是用於臨時存儲變量。在完整的代碼中,大約95%的運行時間用於fastIteration方法,因此即使簡單的性能增益也會顯着減少執行此代碼的時間。

的numpy的方法的輸出看起來好像:

numpy result

由於這個代碼是我的瓶頸,我想用編織熱捧提高速度。這種方法看起來像:

def blitzIteration(self): 
    ### does not work correct so far 
    g = self.grid 
    nx,ny = g.ux.shape 

    uxold = g.old_ux 
    ux = g.ux 
    expr = "ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])" 
    weave.blitz(expr, check_size=0) 
    g.setBC() 
    g.old_ux = ux.copy() 

然而,這並不產生正確的輸出:(fixed轉載,提交併有一個關於實際錯誤的詳細信息那裏) output for blitz code

回答

2

它看起來像在weave.blitz的錯誤。

我認爲這是奇怪的寫0:而不是更短的:得到一個完整的切片,所以我取代了所有這些片和voilà,它的工作。

我真的不知道哪裏的錯誤所在,但weave.blitz產生的expr_code略有不同:

  • 當使用0:

    ipdb> expr_code 
    'ux_blitz_buggy(blitz::Range(0,_end),blitz::Range(1,Nux_blitz_buggy(1)-1-1))=uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(blitz::Range(0,_end),blitz::Range(2,_end))-2*uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+uxold(blitz::Range(0,_end),blitz::Range(0,Nuxold(1)-2-1)));\n' 
    
  • 當使用:

    ipdb> expr_code 
    'ux_blitz_not_buggy(_all,blitz::Range(1,Nux_blitz_not_buggy(1)-1-1))=uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(_all,blitz::Range(2,_end))-2*uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+uxold(_all,blitz::Range(0,Nuxold(1)-2-1)));\n' 
    

因此,blitz::Range(0,_end)變成_all並且它們的行爲方式不同。

爲方便起見,下面是一個完整的腳本,它重現了問題,只會在問題存在時成功。

import numpy as np 
from scipy.weave import blitz 


def test_blitz_bug(N=4): 
    ReI = 1.2 
    ux_blitz_buggy, ux_blitz_not_buggy, ux_np = np.zeros((N, N)), np.zeros((N, N)), np.zeros((N, N)) 
    uxold = np.random.randn(N, N) 
    ux_np[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2]) 
    expr_buggy = 'ux_blitz_buggy[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])' 
    expr_not_buggy = 'ux_blitz_not_buggy[:,1:-1] = uxold[:,1:-1] + ReI* (uxold[:,2:] - 2*uxold[:,1:-1] + uxold[:,0:-2])' 
    blitz(expr_buggy) 
    blitz(expr_not_buggy) 
    assert not np.allclose(ux_blitz_buggy, ux_np) 
    assert np.allclose(ux_blitz_not_buggy, ux_np) 

if __name__ == '__main__': 
    test_blitz_bug() 
+1

@jordeca:這裏是: '$蟒蛇blitz_bug.py' '$蟒蛇-c 「進口SciPy的;打印SciPy的.__版本__」' 0.13.0.dev-639ef30 '$蟒蛇 - c「import numpy; print numpy .__ version __」' 1.7.1 '$ uname -a' Linux ratatoskr 2.6.32-45-generic#104-Ubuntu SMP Tue Feb 19 21:20:09 UTC 2013 x86_64 GNU/Linux – 2013-04-26 21:58:43

+0

@ Zhenya謝謝! – jorgeca 2013-04-27 14:13:22