我第一次使用cython來獲得一些函數的速度。該函數採用方形矩陣A
浮點數並輸出單個浮點數。它是計算功能是permanent of a matrix這個cython代碼可以優化嗎?
當A爲30×30我的代碼目前需要大約60秒我的電腦上。
在下面的代碼中,我已經爲維基頁面上的永久實現了Balasubramanian-Bax/Franklin-Glynn公式。我稱之爲矩陣M.
代碼的一個複雜部分是數組f,它用於保存下一個要在數組d中翻轉的位置的索引。數組d保存的值是+ -1。循環中f和j的操作只是快速更新格雷碼的聰明方法。
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython
DTYPE_int = np.int
ctypedef np.int_t DTYPE_int_t
DTYPE_float = np.float64
ctypedef np.float64_t DTYPE_float_t
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False) # turn off negative index wrapping for entire function
def permfunc(np.ndarray [DTYPE_float_t, ndim =2, mode='c'] M):
cdef int n = M.shape[0]
cdef np.ndarray[DTYPE_float_t, ndim =1, mode='c' ] d = np.ones(n, dtype=DTYPE_float)
cdef int j = 0
cdef int s = 1
cdef np.ndarray [DTYPE_int_t, ndim =1, mode='c'] f = np.arange(n, dtype=DTYPE_int)
cdef np.ndarray [DTYPE_float_t, ndim =1, mode='c'] v = M.sum(axis=0)
cdef DTYPE_float_t p = 1
cdef int i
cdef DTYPE_float_t prod
for i in range(n):
p *= v[i]
while (j < n-1):
for i in range(n):
v[i] -= 2*d[j]*M[j, i]
d[j] = -d[j]
s = -s
prod = 1
for i in range(n):
prod *= v[i]
p += s*prod
f[0] = 0
f[j] = f[j+1]
f[j+1] = j+1
j = f[0]
return p/2**(n-1)
我已經使用了我在cython教程中找到的所有簡單優化。我不得不承認我不完全理解的一些方面。例如,如果我將數組d
整數,因爲這些值只有+1,所以代碼運行速度大約慢了10%,所以我將其作爲float64s。
有什麼我可以做的,以加快代碼?
這是用Cython -a的結果。正如你可以看到循環中的所有內容都被編譯爲C,所以基本的優化已經奏效。
這裏是numpy的同樣的功能比我目前用Cython版本慢100倍以上。
def npperm(M):
n = M.shape[0]
d = np.ones(n)
j = 0
s = 1
f = np.arange(n)
v = M.sum(axis=0)
p = np.prod(v)
while (j < n-1):
v -= 2*d[j]*M[j]
d[j] = -d[j]
s = -s
prod = np.prod(v)
p += s*prod
f[0] = 0
f[j] = f[j+1]
f[j+1] = j+1
j = f[0]
return p/2**(n-1)
時序更新
這裏是定時我用Cython版本,該版本numpy的和romeric的改進到用Cython代碼(使用IPython中)。我爲可重複性設定了種子。
from scipy.stats import ortho_group
import pyximport; pyximport.install()
import permlib # This loads in the functions from permlib.pyx
import numpy as np; np.random.seed(7)
M = ortho_group.rvs(23) #Creates a random orthogonal matrix
%timeit permlib.npperm(M) # The numpy version
1 loop, best of 3: 44.5 s per loop
%timeit permlib.permfunc(M) # The cython version
1 loop, best of 3: 273 ms per loop
%timeit permlib.permfunc_modified(M) #romeric's improvement
10 loops, best of 3: 198 ms per loop
M = ortho_group.rvs(28)
%timeit permlib.permfunc(M) # The cython version run on a 28x28 matrix
1 loop, best of 3: 15.8 s per loop
%timeit permlib.permfunc_modified(M) # romeric's improvement run on a 28x28 matrix
1 loop, best of 3: 12.4 s per loop
能否用Cython代碼可以加速呢?
我用gcc和CPU是AMD FX 8350.
是:你可以問這個上[codereview.se。 – usr2564301
@RadLexus謝謝。然而,看起來在這裏很少有cython問題。曾經有過30次! – eleanora
@eleanora:這樣的推理使得這個數字一直很低。 – Mat