2016-03-03 70 views
6

從我讀過的書中,numba可以顯着加快python程序的速度。使用numba可以提高我的程序的時間效率嗎?使用numba來加速循環

import numpy as np 

def f_big(A, k, std_A, std_k, mean_A=10, mean_k=0.2, hh=100): 
    return (1/(std_A * std_k * 2 * np.pi)) * A * (hh/50) ** k * np.exp(-1*(k - mean_k)**2/(2 * std_k **2) - (A - mean_A)**2/(2 * std_A**2)) 

outer_sum = 0 
dk = 0.000001 
for k in np.arange(dk,0.4, dk): 
    inner_sum = 0 
    for A in np.arange(dk, 20, dk): 
     inner_sum += dk * f_big(A, k, 1e-5, 1e-5) 
    outer_sum += inner_sum * dk 

print outer_sum 

回答

5

是的,這是Numba真正適用的問題。我改變了dk的價值,因爲它對於一個簡單的演示來說是不明智的。下面是代碼:

import numpy as np 
import numba as nb 

def f_big(A, k, std_A, std_k, mean_A=10, mean_k=0.2, hh=100): 
    return (1/(std_A * std_k * 2 * np.pi)) * A * (hh/50) ** k * np.exp(-1*(k - mean_k)**2/(2 * std_k **2) - (A - mean_A)**2/(2 * std_A**2)) 

def func(): 
    outer_sum = 0 
    dk = 0.01 #0.000001 
    for k in np.arange(dk, 0.4, dk): 
     inner_sum = 0 
     for A in np.arange(dk, 20, dk): 
      inner_sum += dk * f_big(A, k, 1e-5, 1e-5) 
     outer_sum += inner_sum * dk 

    return outer_sum 

@nb.jit(nopython=True) 
def f_big_nb(A, k, std_A, std_k, mean_A=10, mean_k=0.2, hh=100): 
    return (1/(std_A * std_k * 2 * np.pi)) * A * (hh/50) ** k * np.exp(-1*(k - mean_k)**2/(2 * std_k **2) - (A - mean_A)**2/(2 * std_A**2)) 

@nb.jit(nopython=True) 
def func_nb(): 
    outer_sum = 0 
    dk = 0.01 #0.000001 
    X = np.arange(dk, 0.4, dk) 
    Y = np.arange(dk, 20, dk) 
    for i in xrange(X.shape[0]): 
     k = X[i] # faster to do lookup than iterate over an array directly 
     inner_sum = 0 
     for j in xrange(Y.shape[0]): 
      A = Y[j] 
      inner_sum += dk * f_big_nb(A, k, 1e-5, 1e-5) 
     outer_sum += inner_sum * dk 

    return outer_sum 

然後計時:

In [7]: np.allclose(func(), func_nb()) 
Out[7]: True 

In [8]: %timeit func() 
1 loops, best of 3: 222 ms per loop 

In [9]: %timeit func_nb() 
The slowest run took 419.10 times longer than the fastest. This could mean that an intermediate result is being cached 
1000 loops, best of 3: 362 µs per loop 

所以numba版本大約是在我的筆記本電腦快600倍。

+0

也許nitpicky,而不是使用''@ nb.jit(nopython = True)''你可以使用''nb.njit''而不''nopython = True''。 – MSeifert

+2

@ MSeifert我傾向於習慣於使用這種形式,因爲我會經常參數化它,所以我可以在測試過程中輕鬆地來回切換 – JoshAdel