2017-09-04 142 views
1

可以通過Python腳本將一個cdef Cython函數傳遞給另一個(python def)cython函數嗎?通過python接口傳遞cython函數

小例子:

test_module.pyx

cpdef min_arg(f, int N): 
    cdef double x = 100000. 
    cdef int best_i = -1 

    for i in range(N): 
     if f(i) < x: 
      x = f(i) 
      best_i = i 
    return best_i 

def py_f(x): 
    return (x-5)**2 

cdef public api double cy_f(double x): 
    return (x-5)**2 

test.py

import pyximport; pyximport.install() 
import testmodule 

testmodule.min_arg(testmodule.py_f, 100) 

這工作不錯,但我希望能夠也做

testmodule.min_arg(testmodule.cy_f, 100) 

從一個test.py,具有cython的速度(每個f(i)調用沒有Python開銷)。但顯然,Python不知道cy_f,因爲它不是defcpdef宣佈。

我希望這樣的事情存在:

from scipy import LowLevelCallable 
cy_f = LowLevelCallable.from_cython(testmodule, 'cy_f') 
testmodule.min_arg(cy_f, 100) 

但是這給TypeError: 'LowLevelCallable' object is not callable

預先感謝您。

回答

1

LowLevelCallable是一類必須被底層Python模塊接受的函數。這項工作已經做了幾個模塊,包括正交常規scipy.integrate.quad

如果您希望使用相同的包裝方法,則必須經過SciPy的例程,使用它,如scipy.ndimage.generic_filter1dscipy.integrate.quad。但是,代碼位於編譯擴展中。

另一種方法是,如果您的問題在回調中定義合理,則自行實施。我已經在我的代碼中的一個做這一點,所以我張貼的鏈接簡單:

  1. .pxd文件中,我定義了接口cyfunc_d_dhttps://github.com/pdebuyl/skl1/blob/master/skl1/core.pxd
  2. 我可以重新使用該界面中的「基礎「cython模塊https://github.com/pdebuyl/skl1/blob/master/skl1/euler.pyx以及」用戶定義「模塊。

最後的代碼,使簡單的 「用Cython,用Cython」 要求,同時允許對象的傳球在用Cython水平

我適應代碼到你的問題:

  1. test_interface.pxd

    cdef class cyfunc:                               
        cpdef double f(self, double x)                           
    
    cdef class pyfunc(cyfunc):                             
        cdef object py_f                              
        cpdef double f(self, double x)                           
    
  2. test_interface.pyx

    cdef class cyfunc: 
        cpdef double f(self, double x): 
         return 0 
        def __cinit__(self): 
         pass 
    
    
    cdef class pyfunc(cyfunc): 
        cpdef double f(self, double x): 
         return self.py_f(x) 
        def __init__(self, f): 
         self.py_f = f 
    
  3. setup.py

    from setuptools import setup, Extension                          
    from Cython.Build import cythonize                           
    
    setup(                                  
        ext_modules=cythonize((Extension('test_interface', ["test_interface.pyx"]),                
              Extension('test_module', ["test_module.pyx"]))                 
            )                              
    )                                   
    
  4. test_module.pyx

    from test_interface cimport cyfunc, pyfunc                         
    
    cpdef min_arg(f, int N):                             
        cdef double x = 100000.                             
        cdef int best_i = -1                             
        cdef int i                                
        cdef double current_value                            
    
        cdef cyfunc py_f                              
    
        if isinstance(f, cyfunc):                            
         py_f = f                               
         print('cyfunc')                              
        elif callable(f):                              
         py_f = pyfunc(f)                             
         print('no cyfunc')                             
        else:                                 
         raise ValueError("f should be a callable or a cyfunc")                    
    
        for i in range(N):                              
         current_value = py_f.f(i)                           
         if current_value < x:                            
          x = current_value                            
          best_i = i                              
        return best_i                               
    
    def py_f(x):                                
        return (x-5)**2                               
    
    cdef class cy_f(cyfunc):                             
        cpdef double f(self, double x):                           
         return (x-5)**2                              
    

要使用:

python3 setup.py build_ext --inplace 
python3 -c 'import test_module ; print(test_module.min_arg(test_module.cy_f(), 10))' 
python3 -c 'import test_module ; print(test_module.min_arg(test_module.py_f, 10))'