2013-02-07 52 views
2

如何提高此功能的速度?嵌套循環爲numpy convolve

def foo(mri_data, radius): 

    mask = mri_data.copy() 

    ny = len(mri_data[0,:]) 
    nx = len(mri_data[:]) 

    for y in xrange(0, ny): 
     for x in xrange(0, nx): 
      if (mri_data[x-radius:x+radius,y-radius:y+radius] != 1.0).all(): 
       mask[x,y] = 0.0      
    return mask.copy() 

它採用numpy陣列形式的圖像切片。遍歷每個像素並測試該像素周圍的邊界框。如果框中的值不等於1,則我們通過將其設置爲0來丟棄該像素。

我被告知我可以使用numpy.convolve,但我不確定這是如何涉及的。

編輯:圖像值在二進制範圍內,所以最低值爲0.0,最大值爲1.0。數值在例如0.767之間。

+0

小心邊緣。如果你有例如'radius = 3'和'mri_data = np.arange(8)'那麼你的第一個窗口是'mri_data [-3:3]',它返回一個空數組... – Jaime

回答

3

您可以濫用卷積的情況之一。我不會用它,但是界限,否則乏味......

from scipy.ndimage import convolve 

not_one = (mri_data != 1.0) # are you sure you want to compare with float like that?! 

conv = convolve(not_one, np.ones((2*radius, 2*radius))) 
all_not_one = (conv == (2*radius)**2) 

mask[all_not_one] = 0 

應該做的確實同樣的事情...

3

你在做什麼叫做binary_dilation,但有一個小錯誤代碼。特別是當x,y小於半徑時,你會得到負指數。這些負數是使用numpy索引規則解釋的,這不是您想要的more on indexing here,因此在圖像的兩個邊緣給出了錯誤的結果。

這是一些使用二進制擴展來完成相同的事情,並修復上述錯誤的代碼。

import numpy as np 
from scipy.ndimage import binary_dilation 

def foo(mri_data, radius): 
    structure = np.ones((2*radius, 2*radius)) 
    # I set the origin here to match your code 
    mask = binary_dilation(mri_data == 1, structure, origin=-1) 
    return np.where(mask, mri_data, 0) 
+0

感謝Bi Rico。我不知道我做的這種門檻已經是過濾器了。我將不得不嘗試你的解決方案。 –