2014-10-05 110 views
3

使用Cython,有沒有辦法編寫快速通用函數,這些函數適用於具有不同維數的數組?例如對於去混疊功能,這個簡單的例子:通用函數迭代n-D數組

import numpy as np 
cimport numpy as np 

ctypedef np.uint8_t DTYPEb_t 
ctypedef np.complex128_t DTYPEc_t 


def dealiasing1D(DTYPEc_t[:, :] data, 
       DTYPEb_t[:] where_dealiased): 
    """Dealiasing data for 1D solvers.""" 
    cdef Py_ssize_t ik, i0, nk, n0 

    nk = data.shape[0] 
    n0 = data.shape[1] 

    for ik in range(nk): 
     for i0 in range(n0): 
      if where_dealiased[i0]: 
       data[ik, i0] = 0. 


def dealiasing2D(DTYPEc_t[:, :, :] data, 
       DTYPEb_t[:, :] where_dealiased): 
    """Dealiasing data for 2D solvers.""" 
    cdef Py_ssize_t ik, i0, i1, nk, n0, n1 

    nk = data.shape[0] 
    n0 = data.shape[1] 
    n1 = data.shape[2] 

    for ik in range(nk): 
     for i0 in range(n0): 
      for i1 in range(n1): 
       if where_dealiased[i0, i1]: 
        data[ik, i0, i1] = 0. 


def dealiasing3D(DTYPEc_t[:, :, :, :] data, 
       DTYPEb_t[:, :, :] where_dealiased): 
    """Dealiasing data for 3D solvers.""" 
    cdef Py_ssize_t ik, i0, i1, i2, nk, n0, n1, n2 

    nk = data.shape[0] 
    n0 = data.shape[1] 
    n1 = data.shape[2] 
    n2 = data.shape[3] 

    for ik in range(nk): 
     for i0 in range(n0): 
      for i1 in range(n1): 
       for i2 in range(n2): 
        if where_dealiased[i0, i1, i2]: 
         data[ik, i0, i1, i2] = 0. 

在這裏,我需要爲一維,二維和三維的情況下三種功能。是否有一個好的方法來編寫一個可以完成所有(合理)維度的工作的函數?

PS:在這裏,我試過使用記憶體,但我不確定這是正確的方法來做到這一點。令我感到驚訝的是,在cython -a命令生成的帶註釋的html中,行if where_dealiased[i0]: data[ik, i0] = 0.不是白色。有什麼不對?

回答

2

我要說的第一件事是,有理由想要保留3個函數,並且使用更通用的函數,您可能會錯過cython編譯器和c編譯器的優化。

做一個包裝這3個函數的函數是非常可行的,它只需將兩個數組作爲python對象,檢查形狀並調用相關的其他函數。

但如果是要嘗試這個然後我會嘗試只是寫了最高層面的功能,然後用低維數組使用new axis符號重塑他們作爲高維數組:

cdef np.uint8_t [:] a1d = np.zeros((256,), np.uint8) # 1d 
cdef np.uint8_t [:, :] a2d = a1d[None, :]    # 2d 
cdef np.uint8_t [:, :, :] a3d = a1d[None, None, :] # 3d 
a2d[0, 100] = 42 
a3d[0, 0, 200] = 108 
print(a1d[100], a1d[200]) 
# (42, 108) 

cdef np.uint8_t [:, :] data2d = np.zeros((128, 256), np.uint8) #2d 
cdef np.uint8_t [:, :, :, :] data4d = data2d[None, None, :, :] #4d 
data4d[0, 0, 42, 108] = 64 
print(data2d[42, 108]) 
# 64 

正如您所看到的,內存視圖可以轉換爲更高維度,並可用於修改原始數據。您可能仍然希望編寫一個包裝函數,它在將新視圖傳遞給最高維函數之前執行這些技巧。我懷疑這個竅門在你的情況下會工作得很好,但是你必須四處遊玩,以便知道它是否會按照你的想法來處理你的數據。

隨着你的PS :,有一個非常簡單的解釋。 '額外代碼'是生成索引錯誤,類型錯誤的代碼,它允許您使用[-1]從數組的末尾索引而不是開始(繞回)。 您可以禁用這些額外的蟒蛇的功能和它降低到C陣列功能通過使用compiler directives,例如以消除整個文件這種額外的代碼,你可以包括在文件的開頭註釋:

# cython: boundscheck=False, wraparound=False, nonecheck=False 

編譯器指令也可以使用裝飾器在功能級應用。該文件解釋說。

1

可以在一般的方式讀取使用numpy.ndindex()strided屬性np.ndarray對象的,扁平陣列,使得所述位置由下式確定:

indices[0]*strides[0] + indices[1]*strides[1] + ... + indices[n]*strides[n] 

這是很容易實現的操作的方式(strides*indices).sum(),當strides是一維數組。下面的代碼說明了如何構造一個工作示例:

#cython profile=True 
#blacython wraparound=False 
#blacython boundscheck=False 
#blacython nonecheck=False 
#blacython cdivision=True 
cimport numpy as np 
import numpy as np 

def readNDArray(x): 
    if not isinstance(x, np.ndarray): 
     raise ValueError('x must be a valid np.ndarray object') 
    if x.itemsize != 8: 
     raise ValueError('x.dtype must be float64') 
    cdef np.ndarray[double, ndim=1] v # view of x 
    cdef np.ndarray[int, ndim=1] strides 
    cdef int pos 

    shape = list(x.shape) 
    strides = np.array([s//x.itemsize for s in x.strides], dtype=np.int32) 
    v = x.ravel() 
    for indices in np.ndindex(*shape): 
     pos = (strides*indices).sum() 
     v[pos] = 2. 
    return np.reshape(v, newshape=shape) 

該算法不會複製原始數組,如果它是C連續

def main(): 
    # case 1 
    x = np.array(np.random.random((3,4,5,6)), order='F') 
    y = readNDArray(x) 
    print(np.may_share_memory(x, y)) 
    # case 2 
    x = np.array(np.random.random((3,4,5,6)), order='C') 
    y = readNDArray(x) 
    print np.may_share_memory(x, y) 
    return 0 

結果:

False 
True