2015-12-31

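We have a vectorized NumPy function, get_pos_neg_bitwise, that uses a mask = [132 20 192] and a df of shape (500e3, 4), and we want to accelerate it with Numba. However, the Numba version is 3x slower than the NumPy one.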

from numba import jit 
import numpy as np 
from time import time 

def get_pos_neg_bitwise(df, mask): 
    """ 
    In [1]: print(mask)
    [132 20 192]

    In [2]: print(df)
    [[ 1 162 97 41] 
    [ 0 136 135 171] 
    ..., 
    [ 0 245 30 73]] 

    """ 
    check = (np.bitwise_and(mask, df[:, 1:]) == mask).all(axis=1) 
    pos = (df[:, 0] == 1) & check 
    neg = (df[:, 0] == 0) & check 
    pos = np.nonzero(pos)[0] 
    neg = np.nonzero(neg)[0] 
    return (pos, neg) 
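To make the mask test concrete: `(row & mask) == mask` holds only when every bit set in `mask` is also set in the corresponding column, and any extra bits in the row are ignored. The toy rows below are made up for illustration (they are not the question's data) and go through the same expressions as `get_pos_neg_bitwise`:

```python
import numpy as np

mask = np.array([132, 20, 192])
# Hypothetical toy rows: column 0 is the 0/1 target, columns 1-3 are tested against mask.
df = np.array([
    [1, 132, 20, 192],   # exactly the mask bits -> passes, target 1
    [0, 255, 255, 255],  # all bits set -> passes, target 0
    [1, 0, 20, 192],     # 0 & 132 == 0 != 132 -> fails
    [0, 133, 21, 193],   # extra bits (1, 1, 1) are ignored -> passes, target 0
])

# Same logic as get_pos_neg_bitwise: a row passes if all masked bits are present.
check = (np.bitwise_and(mask, df[:, 1:]) == mask).all(axis=1)
pos = np.nonzero((df[:, 0] == 1) & check)[0]
neg = np.nonzero((df[:, 0] == 0) & check)[0]
print(check)     # [ True  True False  True]
print(pos, neg)  # [0] [1 3]
```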

Using tips from @morningsun, we wrote this Numba version, but it is slower than the NumPy one (~0.06 s vs ~0.02 s):

@jit(nopython=True)
def numba_get_pos_neg_bitwise(df, mask):
    posneg = np.zeros((df.shape[0], 2))
    for idx in range(df.shape[0]):
        vandmask = np.bitwise_and(df[idx, 1:], mask)

        # numba fails with: if np.all(vandmask == mask):
        vandm_equal_m = 1
        for i, val in enumerate(vandmask):
            if val != mask[i]:
                vandm_equal_m = 0
                break
        if vandm_equal_m == 1:
            if df[idx, 0] == 1:
                posneg[idx, 0] = 1
            else:
                posneg[idx, 1] = 1
    pos = list(np.nonzero(posneg[:, 0])[0])
    neg = list(np.nonzero(posneg[:, 1])[0])
    return (pos, neg)

But it is still about 3 times slower.

if __name__ == '__main__':

    df = np.array(np.random.randint(256, size=(int(500e3), 4)))
    df[:, 0] = np.random.randint(2, size=df.shape[0])  # set target to 0 or 1
    mask = np.array([132, 20, 192])

    start = time()
    pos, neg = get_pos_neg_bitwise(df, mask)
    msg = '==> pos, neg made; p={}, n={} in [{:.4} s] numpy'
    print(msg.format(len(pos), len(neg), time() - start))

    # first Numba call: includes JIT compilation time
    start = time()
    msg = '==> pos, neg made; p={}, n={} in [{:.4} s] numba'
    pos, neg = numba_get_pos_neg_bitwise(df, mask)
    print(msg.format(len(pos), len(neg), time() - start))
    # second Numba call: already compiled
    start = time()
    pos, neg = numba_get_pos_neg_bitwise(df, mask)
    print(msg.format(len(pos), len(neg), time() - start))

What am I missing?

In [1]: %run numba_test2.py 
==> pos, neg made; p=3852, n=3957 in [0.02306 s] numpy 
==> pos, neg made; p=3852, n=3957 in [0.3492 s] numba 
==> pos, neg made; p=3852, n=3957 in [0.06425 s] numba 
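The first Numba timing (0.3492 s) is much larger than the second (0.06425 s) because `@jit` compiles the function on its first call, so that measurement includes compilation. Below is a minimal sketch of a timing helper that makes the warm-up explicit. `time_after_warmup` and `count_matches` are hypothetical names, and the fallback decorator is an assumption that only lets the snippet run where Numba is not installed:

```python
import time
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if Numba is missing
    def njit(func):
        return func

def time_after_warmup(func, *args, repeat=5):
    """Hypothetical helper: call once to trigger JIT compilation, then
    return the best wall-clock time of `repeat` further calls."""
    func(*args)  # warm-up call (includes compilation for a jitted function)
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - start)
    return best

@njit
def count_matches(block, mask):
    # Hypothetical example kernel: count rows whose masked bits are all present.
    n = 0
    for i in range(block.shape[0]):
        ok = True
        for j in range(block.shape[1]):
            if (block[i, j] & mask[j]) != mask[j]:
                ok = False
                break
        if ok:
            n += 1
    return n

rng = np.random.default_rng(0)
block = rng.integers(0, 256, size=(10_000, 3))
mask = np.array([132, 20, 192])
print(time_after_warmup(count_matches, block, mask))
```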

Answer


Try moving the call to np.bitwise_and outside of the loop, since Numba can't do anything to speed it up:

@jit(nopython=True)
def numba_get_pos_neg_bitwise(df, mask):
    posneg = np.zeros((df.shape[0], 2))
    # one vectorized call: mask broadcasts across every row of df[:, 1:]
    vandmask = np.bitwise_and(df[:, 1:], mask)

    for idx in range(df.shape[0]):

        # numba fails with: if np.all(vandmask == mask):
        vandm_equal_m = 1
        for i, val in enumerate(vandmask[idx]):
            if val != mask[i]:
                vandm_equal_m = 0
                break
        if vandm_equal_m == 1:
            if df[idx, 0] == 1:
                posneg[idx, 0] = 1
            else:
                posneg[idx, 1] = 1
    pos = np.nonzero(posneg[:, 0])[0]
    neg = np.nonzero(posneg[:, 1])[0]
    return (pos, neg)

Then I get these timings:

==> pos, neg made; p=3920, n=4023 in [0.02352 s] numpy 
==> pos, neg made; p=3920, n=4023 in [0.2896 s] numba 
==> pos, neg made; p=3920, n=4023 in [0.01539 s] numba 

So now Numba is a bit faster than NumPy.
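Hoisting `np.bitwise_and` out of the loop does not change the result: `mask` has shape `(3,)` and broadcasts against every row of `df[:, 1:]`, which has shape `(n, 3)`, so one vectorized call produces the same array that the per-row calls built one row at a time. A quick check on arbitrary random data:

```python
import numpy as np

rng = np.random.default_rng(42)
df = rng.integers(0, 256, size=(1000, 4))
mask = np.array([132, 20, 192])

# One call: mask (3,) broadcasts across all rows of df[:, 1:] (1000, 3).
hoisted = np.bitwise_and(df[:, 1:], mask)

# Per-row calls, as in the question's loop.
per_row = np.array([np.bitwise_and(df[idx, 1:], mask) for idx in range(df.shape[0])])

print(hoisted.shape, np.array_equal(hoisted, per_row))  # (1000, 3) True
```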

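Also, it doesn't make a huge difference, but your original function returns NumPy arrays, while in the Numba version you were converting pos and neg to lists.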
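The return-type difference is easy to see outside Numba: `np.nonzero(...)[0]` already returns an integer ndarray, while wrapping it in `list()` builds a Python list with one element per index, which is extra per-element work for large results. A small plain-NumPy illustration:

```python
import numpy as np

flags = np.array([0.0, 1.0, 0.0, 1.0, 1.0])

idx_array = np.nonzero(flags)[0]  # ndarray of indices
idx_list = list(idx_array)        # Python list of NumPy integer scalars

print(type(idx_array).__name__, idx_array)     # ndarray [1 3 4]
print(type(idx_list).__name__, len(idx_list))  # list 3
```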

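In general, I would guess that the running time is dominated by the NumPy function calls, which Numba can't speed up, and the NumPy version of the code already uses fast vectorized routines.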

Update:

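You can make it faster by removing the enumerate call and indexing directly into the array instead of grabbing a slice. Splitting posneg into separate arrays also helps, because it avoids slicing along a non-contiguous axis in memory: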

@jit(nopython=True)
def numba_get_pos_neg_bitwise(df, mask):
    # separate contiguous 1-D flag arrays instead of two columns of one 2-D array
    pos_flags = np.zeros(df.shape[0])
    neg_flags = np.zeros(df.shape[0])
    vandmask = np.bitwise_and(df[:, 1:], mask)

    for idx in range(df.shape[0]):

        # numba fails with: if np.all(vandmask == mask):
        vandm_equal_m = 1
        for i in range(vandmask.shape[1]):
            if vandmask[idx, i] != mask[i]:
                vandm_equal_m = 0
                break
        if vandm_equal_m == 1:
            if df[idx, 0] == 1:
                pos_flags[idx] = 1
            else:
                neg_flags[idx] = 1
    pos = np.nonzero(pos_flags)[0]
    neg = np.nonzero(neg_flags)[0]
    return pos, neg
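The contiguity point from the update can be checked directly with NumPy's `flags` and `strides`: a column of a C-ordered `(n, 2)` float64 array is a strided view whose elements are 16 bytes apart, while a separate 1-D array is contiguous with 8-byte steps:

```python
import numpy as np

n = 5
posneg = np.zeros((n, 2))  # C order (row-major), float64
col = posneg[:, 0]         # view of column 0: every other element in memory

print(col.flags["C_CONTIGUOUS"], col.strides)  # False (16,)
pos = np.zeros(n)
print(pos.flags["C_CONTIGUOUS"], pos.strides)  # True (8,)
```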

And timed in an IPython notebook:

%timeit pos1, neg1 = get_pos_neg_bitwise(df, mask)
%timeit pos2, neg2 = numba_get_pos_neg_bitwise(df, mask)

100 loops, best of 3: 18.2 ms per loop
100 loops, best of 3: 7.89 ms per loop
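Taking the same idea further, the temporary `vandmask` array and the two flag arrays can be dropped entirely by testing each bit inline and writing matching row indices straight into preallocated output arrays. This is a sketch of an alternative that is not part of the answer above and has not been benchmarked here. `fused_pos_neg` is a hypothetical name, targets are assumed to be 0 or 1 as in the question, and the fallback decorator only lets it run without Numba:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: run as plain Python if Numba is missing
    def njit(func):
        return func

@njit
def fused_pos_neg(df, mask):
    n = df.shape[0]
    pos = np.empty(n, dtype=np.int64)  # n is an upper bound on the number of matches
    neg = np.empty(n, dtype=np.int64)
    n_pos = 0
    n_neg = 0
    for idx in range(n):
        ok = True
        for j in range(mask.shape[0]):
            # Column j + 1 of df lines up with mask[j].
            if (df[idx, j + 1] & mask[j]) != mask[j]:
                ok = False
                break
        if ok:
            if df[idx, 0] == 1:
                pos[n_pos] = idx
                n_pos += 1
            else:
                neg[n_neg] = idx
                n_neg += 1
    # Trim to the filled part and copy so the results don't keep the full buffers alive.
    return pos[:n_pos].copy(), neg[:n_neg].copy()

# Compare against the vectorized NumPy logic on random data.
rng = np.random.default_rng(1)
df = rng.integers(0, 256, size=(2000, 4))
df[:, 0] = rng.integers(0, 2, size=df.shape[0])
mask = np.array([132, 20, 192])

check = (np.bitwise_and(mask, df[:, 1:]) == mask).all(axis=1)
expected_pos = np.nonzero((df[:, 0] == 1) & check)[0]
expected_neg = np.nonzero((df[:, 0] == 0) & check)[0]
pos, neg = fused_pos_neg(df, mask)
print(np.array_equal(pos, expected_pos), np.array_equal(neg, expected_neg))  # True True
```

Whether this beats the updated answer depends on the data; the main saving is one fewer `(n, 3)` temporary and no second pass through `np.nonzero`.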