我想學習cython;但是,我一定在做錯事。這小小的測試代碼運行速度比我的矢量化numpy版本慢50倍左右。有人能告訴我爲什麼我的cython比我的python慢嗎?謝謝。cython運行速度比numpy慢,用於距離計算
該代碼計算R^3中的一個點,loc,和R^3中點的數組之間的距離,點。
import numpy as np
cimport numpy as np
import cython
cimport cython
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False) # turn of bounds-checking for entire function
@cython.wraparound(False)
@cython.nonecheck(False)
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points, np.ndarray[DTYPE_t, ndim=1] loc):
cdef unsigned int i
cdef unsigned int L = points.shape[0]
cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L)
for i in xrange(0,L):
d[i] = np.sqrt((points[i,0] - loc[0])**2 + (points[i,1] - loc[1])**2 + (points[i,2] - loc[2])**2)
return d
這是與它進行比較的numpy代碼。
from numpy import *
N = 1e6
points = random.uniform(0,1,(N,3))
loc = random.uniform(0,1,(3))
def distMeasureNumpy(points,loc):
d = points - loc
d = sqrt(sum(d*d,axis=1))
return d
numpy/python版本大約需要44ms,而cython版本大約需要2秒。我在mac osx上運行python 2.7。我使用ipython的%timeit命令來計時這兩個函數。
我沒有看到任何明顯的錯誤與cython版本(我驚訝它是如此緩慢)。但是,你不打算用cython擊敗向量化的numpy表達式。對於無法被矢量化的操作,Cython是最好的(並且通常相當不錯)。另外,你可以使用'd = np.hypot(* d.T)'來稍微加快你的numpy版本。 –
你運行過'cython -a your_code.pyx'並查看過'your_code.html'嗎?這是一種檢查由cython生成的C代碼的便捷方式,並且可以找出已經轉換爲C的數量以及在Python級別還有多少工作量。 –