2013-06-03 85 views
2

我想學習cython;但是,我一定在做錯事。這小小的測試代碼運行速度比我的矢量化numpy版本慢50倍左右。有人能告訴我爲什麼我的cython比我的python慢​​嗎?謝謝。cython運行速度比numpy慢,用於距離計算

該代碼計算R^3中的一個點,loc,和R^3中點的數組之間的距離,點。

import numpy as np 
cimport numpy as np 
import cython 
cimport cython 

DTYPE = np.float64 
ctypedef np.float64_t DTYPE_t 
@cython.boundscheck(False) # turn of bounds-checking for entire function 
@cython.wraparound(False) 
@cython.nonecheck(False) 
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points, np.ndarray[DTYPE_t, ndim=1] loc): 
    cdef unsigned int i 
    cdef unsigned int L = points.shape[0] 
    cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L) 
    for i in xrange(0,L): 
     d[i] = np.sqrt((points[i,0] - loc[0])**2 + (points[i,1] - loc[1])**2 + (points[i,2] - loc[2])**2) 
    return d 

這是與它進行比較的numpy代碼。

from numpy import * 
N = 1e6 
points = random.uniform(0,1,(N,3)) 
loc = random.uniform(0,1,(3)) 

def distMeasureNumpy(points,loc): 
    d = points - loc 
    d = sqrt(sum(d*d,axis=1)) 
    return d 

numpy/python版本大約需要44ms,而cython版本大約需要2秒。我在mac osx上運行python 2.7。我使用ipython的%timeit命令來計時這兩個函數。

+0

我沒有看到任何明顯的錯誤與cython版本(我驚訝它是如此緩慢)。但是,你不打算用cython擊敗向量化的numpy表達式。對於無法被矢量化的操作,Cython是最好的(並且通常相當不錯)。另外,你可以使用'd = np.hypot(* d.T)'來稍微加快你的numpy版本。 –

+0

你運行過'cython -a your_code.pyx'並查看過'your_code.html'嗎?這是一種檢查由cython生成的C代碼的便捷方式,並且可以找出已經轉換爲C的數量以及在Python級別還有多少工作量。 –

回答

5

np.sqrt這是一個Python函數調用的調用正在破壞您的性能您正在計算標量浮點值的平方根,因此您應該使用C數學庫中的sqrt函數。以下是您的代碼的修改版本:

import numpy as np 
cimport numpy as np 
import cython 
cimport cython 

from libc.math cimport sqrt 

DTYPE = np.float64 
ctypedef np.float64_t DTYPE_t 
@cython.boundscheck(False) # turn of bounds-checking for entire function 
@cython.wraparound(False) 
@cython.nonecheck(False) 
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points, 
         np.ndarray[DTYPE_t, ndim=1] loc): 
    cdef unsigned int i 
    cdef unsigned int L = points.shape[0] 
    cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L) 
    for i in xrange(0,L): 
     d[i] = sqrt((points[i,0] - loc[0])**2 + 
        (points[i,1] - loc[1])**2 + 
        (points[i,2] - loc[2])**2) 
    return d 

以下演示了性能改進。您的原始代碼是模塊check_speed_original中,而修改後的版本是check_speed

In [11]: import check_speed_original 

In [12]: import check_speed 

設置測試數據:

In [13]: N = 10**6 

In [14]: points = random.uniform(0,1,(N,3)) 

In [15]: loc = random.uniform(0,1,(3,)) 

原來的版本需要在我的電腦上1.26秒:

In [16]: %timeit check_speed_original.distMeasureCython(points, loc) 
1 loops, best of 3: 1.26 s per loop 

修改後的版本需要4.47 毫秒

In [17]: %timeit check_speed.distMeasureCython(points, loc) 
100 loops, best of 3: 4.47 ms per loop 

萬一有人擔心,結果可能會有所不同:

In [18]: d1 = check_speed.distMeasureCython(points, loc) 

In [19]: d2 = check_speed_original.distMeasureCython(points, loc) 

In [20]: np.all(d1 == d2) 
Out[20]: True 
+0

它的工作原理!謝謝。我正在獲得上面提到的運行時間。感謝關於html的提示。這是我的第一個問題,我現在應該關閉它,現在它已經解決了嗎?感謝WW的幫助。 – plancherel

3

前面已經提到,它是在代碼中調用numpy.sqrt。不過,我認爲不需要使用cdef extern,因爲Cython已經提供了這些基本的C/C++庫。 (見the docs)。所以你可以像這樣輸入:

from libc.math cimport sqrt 

只是爲了擺脫開銷。

+0

好點,謝謝。我已經更新了我的答案。 –