2016-12-09 58 views
2

由於性能的原因,除了NumPy之外,我已經開始使用Numba了。我的Numba算法正在工作,但我有一種感覺,它應該更快。有一點是減緩它。以下是代碼片段:在numba中的性能嵌套循環

@nb.njit 
def rfunc1(ws, a, l): 
    gn = a**l 
    for x1 in range(gn): 
     for x2 in range(gn): 
      for x3 in range(gn): 
       y = 0.0 
       for i in range(1, l): 
        if numpy.all(ws[x1][0:i] == ws[x2][0:i]) and 
        numpy.all(ws[x1][i:l] == ws[x3][i:l]): 
         y += 1 
        if numpy.all(ws[x1][0:i] == ws[x2][0:i]) and 
        numpy.all(ws[x1][i:l] == ws[x3][i:l]): 
         y += 1 

在我看來,if命令減緩下來。有沒有更好的辦法? (我試圖在這裏實現是與先前發佈的問題是什麼:Count possibilites for single crossoversws是尺寸含0(gn, l)的NumPy的陣列的和1

+0

你意識到這種規模可怕地與'gn'的大小...? –

+0

是的,l的最大大小是9,a總是2 – HighwayJohn

+0

你在Python 2還是3? –

回答

2

鑑於希望確保所有項目都是平等的邏輯,你可以利用這樣一個事實,即如果有任何不相等的事實,則可以將計算短路(即停止比較)。我稍微修改原來的功能,以便(1)你不要重複相同的比較兩次,和(2)值Y在所有的嵌套循環,從而有可能進行比較的回報:

@nb.njit 
def rfunc1(ws, a, l): 
    gn = a**l 
    ysum = 0 
    for x1 in range(gn): 
     for x2 in range(gn): 
      for x3 in range(gn): 
       y = 0.0 
       for i in range(1, l): 
        if np.all(ws[x1][0:i] == ws[x2][0:i]) and np.all(ws[x1][i:l] == ws[x3][i:l]): 
         y += 1 
         ysum += 1 

    return ysum 


@nb.njit 
def rfunc2(ws, a, l): 
    gn = a**l 
    ysum = 0 
    for x1 in range(gn): 
     for x2 in range(gn): 
      for x3 in range(gn): 
       y = 0.0 
       for i in range(1, l): 

        incr_y = True 
        for j in range(i): 
         if ws[x1,j] != ws[x2,j]: 
          incr_y = False 
          break 

        if incr_y is True: 
         for j in range(i,l): 
          if ws[x1,j] != ws[x3,j]: 
           incr_y = False 
           break 
        if incr_y is True: 
         y += 1 
         ysum += 1 
    return ysum 

我不知道完整的功能是什麼樣子,但希望這可以幫助你開始正確的道路。

現在對於一些計時:

l = 7 
a = 2 
gn = a**l 
ws = np.random.randint(0,2,size=(gn,l)) 
In [23]: 

%timeit rfunc1(ws, a , l) 
1 loop, best of 3: 2.11 s per loop 


%timeit rfunc2(ws, a , l) 
1 loop, best of 3: 39.9 ms per loop 

In [27]: rfunc1(ws, a , l) 
Out[27]: 131919 

In [30]: rfunc2(ws, a , l) 
Out[30]: 131919 

這就給了你50倍的加速。

+0

如何在'nopython = True'中使用'jit'? – pbreach

+0

'njit'相當於'jit(nopython = True)' – JoshAdel

+0

非常感謝! :) – HighwayJohn

2

而不只是「有感覺」在您的瓶頸,爲什麼不輪廓你的代碼,並找到究竟在哪裏呢?

性能分析的第一個目標是測試一個代表性的系統,以確定什麼是緩慢的(或使用太多RAM,或導致太多的磁盤I/O或網絡I/O)。

性能分析通常會增加開銷(10倍到100倍的減速可能是典型的),您仍然希望儘可能使用類似於實際情況的代碼。提取測試用例並隔離您需要測試的系統部分。最好是已經寫入它自己的一組模塊中。基本技術包括IPython中的%timeit魔術,time.time(),timing decorator(請參見下面的示例)。您可以使用這些技術來了解語句和函數的行爲。

然後你有cProfile這將給你一個問題的高層次的看法,所以你可以引導你的注意力到關鍵的功能。

接下來,看看line_profiler,這將逐行分析您選擇的功能。結果將包括每行被調用的次數以及每行所用時間的百分比。這正是您瞭解緩慢運行以及爲什麼需要的信息。

perf stat幫助您理解最終在CPU上執行的指令的數量以及CPU的高速緩存利用率。這允許對矩陣操作進行高級調整。

heapy可以跟蹤Python內存中的所有對象。這非常適合尋找奇怪的內存泄漏。如果您使用的是長時間運行的系統,那麼 然後dowser會引起您的興趣:它允許您通過Web瀏覽器界面在長時間運行的過程中反思活動對象。

爲了幫助您理解RAM使用率高的原因,請查看memory_profiler.這對於跟蹤隨時間推移的RAM使用情況特別有用,因爲您可以向同事(或您自己)解釋爲什麼某些功能使用的RAM多於預期。

例:定義一個裝飾來自動定時測量

from functools import wraps 

def timefn(fn): 
    @wraps(fn) 
    def measure_time(*args, **kwargs): 
     t1 = time.time() 
     result = fn(*args, **kwargs) 
     t2 = time.time() 
     print ("@timefn:" + fn.func_name + " took " + str(t2 - t1) + " seconds") 
     return result 
    return measure_time 

@timefn 
def your_func(var1, var2): 
    ... 

有關詳細信息,我建議讀High performance Python(米莎戈雷利克;伊恩Ozsvald),從該上述被採購。

+0

這些都是很好的**一般**建議,但沒有真正適用於這個問題。例如,你不能在numba函數內使用'line_profiler',也不能在'nopython'模式下調用'time.time'。最初的問題是關於改進一個函數的性能(大概已經確定爲熱點),這個函數是在numba中編碼的。通常在那裏,你必須對Numba可以轉換成高性能的llvm代碼有一個直覺,許多通用技術都無法解決這個問題。 – JoshAdel

+0

@JoshAdel:我想向OP建議,不要猜測瓶頸在哪裏,但可以通過配置文件來確定。爲了未來讀者的利益,我試圖使分析選項稍微完整(即使並非所有的都適用於OP的情況)。 – boardrider