2015-12-28 100 views
4

我在此代碼段numba布爾陣列

from numba import jit 
import numpy as np 
from time import time 
db = np.array(np.random.randint(2, size=(400e3, 4)), dtype=bool) 
out = np.zeros((int(400e3), 1)) 

@jit() 
def check_mask(db, out, mask=[1, 0, 1]): 
    for idx, line in enumerate(db): 
     target, vector = line[0], line[1:] 
     if (mask == np.bitwise_and(mask, vector)).all(): 
      if target == 1: 
       out[idx] = 1 
    return out 

st = time() 
res = check_mask(db, out, [1, 0, 1]) 
print 'with jit: {:.4} sec'.format(time() - st) 

隨着numba @jit()裝飾這段代碼的運行速度較慢試圖numba!

  • 沒有JIT:3.16秒
  • 與JIT:3.81秒

只是爲了幫助理解這段代碼的目的更好:

db = np.array([   # out value for mask = [1, 0, 1] 
    # target, vector  # 
     [1,  1, 0, 1], # 1 
     [0,  1, 1, 1], # 0 (fit to mask but target == 0) 
     [0,  0, 1, 0], # 0 
     [1,  1, 0, 1], # 1 
     [0,  1, 1, 0], # 0 
     [1,  0, 0, 0], # 0 
     ]) 
+0

查看'array_equal'代碼,如我最近的回答http://stackoverflow.com/a/34486522/901925中所示。 – hpaulj

+0

Thanks @hpaulj我剛剛更新了片段,以考慮到您的評論 – user3313834

回答

2

Numba有兩種編譯模式jit:nopython模式和對象模式。 Nopython模式(默認)僅支持一組有限的Python和Numpy功能,請參閱the docs for your version。如果jitted函數包含不支持的代碼,Numba必須回退到對象模式,這要慢得多。

我不確定objcet模式與純Python相比是否應該加快速度,但總是希望使用nopython模式。爲了用於確保nopython模式,指定nopython=True並堅持非常基本的代碼(經驗法則:寫出所有的循環,並且只使用標量和numpy的數組):

@jit(nopython=True) 
def check_mask_2(db, out, mask=np.array([1, 0, 1])): 
    for idx in range(db.shape[0]): 
     if db[idx,0] != 1: 
      continue 
     check = 1 
     for j in range(db.shape[1]): 
      if mask[j] and not db[idx,j+1]: 
       check = 0 
       break 
     out[idx] = check 
    return out 

寫出內環明確也有一旦條件失敗,我們可以立即擺脫它的好處。

時序:

%time _ = check_mask(db, out, np.array([1, 0, 1])) 
# Wall time: 1.91 s 
%time _ = check_mask_2(db, out, np.array([1, 0, 1])) 
# Wall time: 310 ms # slow because of compilation 
%time _ = check_mask_2(db, out, np.array([1, 0, 1])) 
# Wall time: 3 ms 

BTW,功能也很容易與numpy的,這給出了一個體面的速度矢量:

def check_mask_vectorized(db, mask=[1, 0, 1]): 
    check = (db[:,1:] == mask).all(axis=1) 
    out = (db[:,0] == 1) & check 
    return out 

%time _ = check_mask_vectorized(db, [1, 0, 1]) 
# Wall time: 14 ms 
+0

感謝您對我的建議使用更復雜但相似的問題,但我沒有成功加快使用numba:http: //stackoverflow.com/q/34544210/3313834 – user3313834

1

我會建議取下numpy的通話從內部循環到array_equal。 numba不一定足夠聰明,可以將它變成一段內聯C;如果它不能取代這個呼叫,你的功能的主要成本仍然可比,這將解釋你的結果。

儘管numba可以推理相當數量的numpy結構,但只有C樣式的代碼可以依賴於加速的numpy數組。

+0

是的,這就是爲什麼我已經刪除了array_equal,取而代之的是np.bitwise_and(mask,vector) – user3313834

+0

我不確定這是一個重大的區別numba。您可能需要手動執行循環遍歷所有掩碼組件,以避免內部循環中的python調用 –

3

或者,你可以嘗試pythran(免責聲明:我是一個developper的pythran)。

使用單個註釋,它編譯以下代碼

#pythran export check_mask(bool[][], bool[]) 

import numpy as np 
def check_mask(db, out, mask=[1, 0, 1]): 
    for idx, line in enumerate(db): 
     target, vector = line[0], line[1:] 
     if (mask == np.bitwise_and(mask, vector)).all(): 
      if target == 1: 
       out[idx] = 1 
    return out 

一起pythran check_call.py的呼叫。

而且根據timeit,所產生的本機模塊運行非常快:

python -m timeit -s 'n=10e3 ; import numpy as np;db = np.array(np.random.randint(2, size=(n, 4)), dtype=bool); out = np.zeros(int(n),dtype=bool); from eq import check_mask' 'check_mask(db, out)' 

告訴我CPython的版本136ms運行,而pythran編譯版本450us運行。

+0

我到達這個問題http://stackoverflow.com/q/35240168/3313834用pythran做到這一點。 BTW pythran值得在stackoverflow上的標籤 – user3313834