我在此代碼段numba布爾陣列
from numba import jit
import numpy as np
from time import time
db = np.array(np.random.randint(2, size=(400e3, 4)), dtype=bool)
out = np.zeros((int(400e3), 1))
@jit()
def check_mask(db, out, mask=[1, 0, 1]):
for idx, line in enumerate(db):
target, vector = line[0], line[1:]
if (mask == np.bitwise_and(mask, vector)).all():
if target == 1:
out[idx] = 1
return out
st = time()
res = check_mask(db, out, [1, 0, 1])
print 'with jit: {:.4} sec'.format(time() - st)
隨着numba @jit()裝飾這段代碼的運行速度較慢試圖numba!
- 沒有JIT:3.16秒
- 與JIT:3.81秒
只是爲了幫助理解這段代碼的目的更好:
db = np.array([ # out value for mask = [1, 0, 1]
# target, vector #
[1, 1, 0, 1], # 1
[0, 1, 1, 1], # 0 (fit to mask but target == 0)
[0, 0, 1, 0], # 0
[1, 1, 0, 1], # 1
[0, 1, 1, 0], # 0
[1, 0, 0, 0], # 0
])
查看'array_equal'代碼,如我最近的回答http://stackoverflow.com/a/34486522/901925中所示。 – hpaulj
Thanks @hpaulj我剛剛更新了片段,以考慮到您的評論 – user3313834