2017-06-03 98 views
0

我有一個函數,我想用numba進行編譯,但是我需要計算該函數內部的階乘。不幸的是numba不支持math.factorial在numba nopython函數中計算階乘的最快方法

import math 
import numba as nb 

@nb.njit 
def factorial1(x): 
    return math.factorial(x) 

factorial1(10) 
# UntypedAttributeError: Failed at nopython (nopython frontend) 

我看到,它支持math.gamma(可以用來計算階乘)表示「整數值,但是違背了真正的math.gamma功能它沒有返回花車「:

@nb.njit 
def factorial2(x): 
    return math.gamma(x+1) 

factorial2(10) 
# 3628799.9999999995 <-- not exact 

math.gamma(11) 
# 3628800.0 <-- exact 

,它的緩慢相比math.factorial

%timeit factorial2(10) 
# 1.12 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 
%timeit math.factorial(10) 
# 321 ns ± 6.12 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 

所以我決定定義自己的功能:

@nb.njit 
def factorial3(x): 
    n = 1 
    for i in range(2, x+1): 
     n *= i 
    return n 

factorial3(10) 
# 3628800 

%timeit factorial3(10) 
# 821 ns ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 

它仍然比math.factorial慢,但是它比基於math.gamma功能numba更快的值是「精確」。

所以我正在尋找最快的方法來計算一個正整數(< = 20;爲了避免溢出)factorial在nopython的numba函數中。

+2

如果你只關心整數「0..20」的階乘因子,那麼查找表可能值得檢查速度。 –

+0

Arrrgggh,在我以前的評論中,我寫了*你的*我應該寫*你是*。或*如果您唯一的擔心是...... * –

+0

您可以嘗試重新實現numba中的python方法 - 它會通過一些額外的步驟來以特定方式對乘法進行排序 - https://github.com/python/ cpython/blob/3.6/Modules/mathmodule.c#L1275 – chrisb

回答

1

對於值< = 20,python正在使用查找表,正如評論中所建議的那樣。 https://github.com/python/cpython/blob/3.6/Modules/mathmodule.c#L1452

LOOKUP_TABLE = np.array([ 
    1, 1, 2, 6, 24, 120, 720, 5040, 40320, 
    362880, 3628800, 39916800, 479001600, 
    6227020800, 87178291200, 1307674368000, 
    20922789888000, 355687428096000, 6402373705728000, 
    121645100408832000, 2432902008176640000], dtype='int64') 

@nb.jit 
def fast_factorial(n): 
    if n > 20: 
     raise ValueError 
    return LOOKUP_TABLE[n] 

從Python中叫它比Python版本稍慢由於numba調度開銷。

In [58]: %timeit math.factorial(10) 
10000000 loops, best of 3: 79.4 ns per loop 

In [59]: %timeit fast_factorial(10) 
10000000 loops, best of 3: 173 ns per loop 

但是在另一個numba函數中調用可以更快。

def loop_python(): 
    for i in range(10000): 
     for n in range(21): 
      math.factorial(n) 

@nb.njit 
def loop_numba(): 
    for i in range(10000): 
     for n in range(21): 
      fast_factorial(n) 

In [65]: %timeit loop_python() 
10 loops, best of 3: 36.7 ms per loop 

In [66]: %timeit loop_numba() 
10000000 loops, best of 3: 73.6 ns per loop 
+0

Numba做了積極的循環優化,所以如果你不保存'fast_factorial'的結果,它甚至不會執行循環。 – MSeifert