2013-07-08 39 views
5

我想解決如何加快使用numpy的Python函數。我從lineprofiler收到的輸出在下面,這表明絕大多數時間都用在了線路ind_y, ind_x = np.where(seg_image == i)上。加快numpy.where提取整數段?

seg_image是一個整數數組,它是分割圖像的結果,因此找到了像素,其中seg_image == i提取了特定的分割對象。我循環了很多這些對象(在下面的代碼中,我只是循環了5次測試,但實際上我會循環超過20,000次),而且它需要很長時間才能運行!

有沒有什麼方法可以加快np.where的呼叫?或者,可以加快倒數第二行(也佔用很大比例的時間)?

理想的解決方案是一次運行整個數組的代碼,而不是循環,但我不認爲這是可能的,因爲有一些我需要運行的函數的副作用(for例如,擴大分割的對象可能會使它與下一個區域「碰撞」,從而在以後出現不正確的結果)。

有沒有人有任何想法?

Line #  Hits   Time Per Hit % Time Line Contents 
============================================================== 
    5           def correct_hot(hot_image, seg_image): 
    6   1  239810 239810.0  2.3  new_hot = hot_image.copy() 
    7   1  572966 572966.0  5.5  sign = np.zeros_like(hot_image) + 1 
    8   1  67565 67565.0  0.6  sign[:,:] = 1 
    9   1  1257867 1257867.0  12.1  sign[hot_image > 0] = -1 
    10           
    11   1   150 150.0  0.0  s_elem = np.ones((3, 3)) 
    12           
    13            #for i in xrange(1,seg_image.max()+1): 
    14   6   57  9.5  0.0  for i in range(1,6): 
    15   5  6092775 1218555.0  58.5   ind_y, ind_x = np.where(seg_image == i) 
    16           
    17             # Get the average HOT value of the object (really simple!) 
    18   5   2408 481.6  0.0   obj_avg = hot_image[ind_y, ind_x].mean() 
    19           
    20   5   333  66.6  0.0   miny = np.min(ind_y) 
    21             
    22   5   162  32.4  0.0   minx = np.min(ind_x) 
    23             
    24           
    25   5   369  73.8  0.0   new_ind_x = ind_x - minx + 3 
    26   5   113  22.6  0.0   new_ind_y = ind_y - miny + 3 
    27           
    28   5   211  42.2  0.0   maxy = np.max(new_ind_y) 
    29   5   143  28.6  0.0   maxx = np.max(new_ind_x) 
    30           
    31             # 7 is + 1 to deal with the zero-based indexing, + 2 * 3 to deal with the 3 cell padding above 
    32   5   217  43.4  0.0   obj = np.zeros((maxy+7, maxx+7)) 
    33           
    34   5   158  31.6  0.0   obj[new_ind_y, new_ind_x] = 1 
    35           
    36   5   2482 496.4  0.0   dilated = ndimage.binary_dilation(obj, s_elem) 
    37   5   1370 274.0  0.0   border = mahotas.borders(dilated) 
    38           
    39   5   122  24.4  0.0   border = np.logical_and(border, dilated) 
    40           
    41   5   355  71.0  0.0   border_ind_y, border_ind_x = np.where(border == 1) 
    42   5   136  27.2  0.0   border_ind_y = border_ind_y + miny - 3 
    43   5   123  24.6  0.0   border_ind_x = border_ind_x + minx - 3 
    44           
    45   5   645 129.0  0.0   border_avg = hot_image[border_ind_y, border_ind_x].mean() 
    46           
    47   5  2167729 433545.8  20.8   new_hot[seg_image == i] = (new_hot[ind_y, ind_x] + (sign[ind_y, ind_x] * np.abs(obj_avg - border_avg))) 
    48   5  10179 2035.8  0.1   print obj_avg, border_avg 
    49           
    50   1   4  4.0  0.0  return new_hot 

回答

4

編輯我已經離開了我原來的答覆底部備案,但其實我也看着你的代碼中午飯更詳細,我覺得用np.where是一個很大的錯誤:

In [63]: a = np.random.randint(100, size=(1000, 1000)) 

In [64]: %timeit a == 42 
1000 loops, best of 3: 950 us per loop 

In [65]: %timeit np.where(a == 42) 
100 loops, best of 3: 7.55 ms per loop 

你可以得到一個布爾數組(你可以使用索引)在1/8的時間內獲得點的實際座標!

當然有,你做的功能裁剪,但ndimage具有find_objects函數返回封閉片,顯得非常快:

In [66]: %timeit ndimage.find_objects(a) 
100 loops, best of 3: 11.5 ms per loop 

這將返回切片的元組的列表附上全部的對象,在50%以上的時間內找到單個對象的索引。

它可能不起作用開箱,我現在不能測試,但我會調整你的代碼轉換成類似以下內容:

def correct_hot_bis(hot_image, seg_image): 
    # Need this to not index out of bounds when computing border_avg 
    hot_image_padded = np.pad(hot_image, 3, mode='constant', 
           constant_values=0) 
    new_hot = hot_image.copy() 
    sign = np.ones_like(hot_image, dtype=np.int8) 
    sign[hot_image > 0] = -1 
    s_elem = np.ones((3, 3)) 

    for j, slice_ in enumerate(ndimage.find_objects(seg_image)): 
     hot_image_view = hot_image[slice_] 
     seg_image_view = seg_image[slice_] 
     new_shape = tuple(dim+6 for dim in hot_image_view.shape) 
     new_slice = tuple(slice(dim.start, 
           dim.stop+6, 
           None) for dim in slice_) 
     indices = seg_image_view == j+1 

     obj_avg = hot_image_view[indices].mean() 

     obj = np.zeros(new_shape) 
     obj[3:-3, 3:-3][indices] = True 

     dilated = ndimage.binary_dilation(obj, s_elem) 
     border = mahotas.borders(dilated) 
     border &= dilated 

     border_avg = hot_image_padded[new_slice][border == 1].mean() 

     new_hot[slice_][indices] += (sign[slice_][indices] * 
            np.abs(obj_avg - border_avg)) 

    return new_hot 

你仍然需要弄清楚碰撞,但你可以通過計算同時使用np.unique爲基礎的方法在所有指數得到約2倍加速:

a = np.random.randint(100, size=(1000, 1000)) 

def get_pos(arr): 
    pos = [] 
    for j in xrange(100): 
     pos.append(np.where(arr == j)) 
    return pos 

def get_pos_bis(arr): 
    unq, flat_idx = np.unique(arr, return_inverse=True) 
    pos = np.argsort(flat_idx) 
    counts = np.bincount(flat_idx) 
    cum_counts = np.cumsum(counts) 
    multi_dim_idx = np.unravel_index(pos, arr.shape) 
    return zip(*(np.split(coords, cum_counts) for coords in multi_dim_idx)) 

In [33]: %timeit get_pos(a) 
1 loops, best of 3: 766 ms per loop 

In [34]: %timeit get_pos_bis(a) 
1 loops, best of 3: 388 ms per loop 

請注意,每個像素Ø bject以不同的順序返回,所以你不能簡單地比較兩個函數的返回值來評估等式。但他們都應該返回相同的。

+0

這是美好的,真棒和驚人的 - 謝謝你!第一次運行它時,我發現它實際上比我的原始代碼慢,但後來我修改了一些代碼,以便在小數組中完成所有工作(擴展,邊框等)而不是巨大的數組 - 通過修改new_shape的計算方式。現在我的速度已經大大提高了。在我正在使用的圖像之一上,舊版本花費了兩個半小時,新版本耗時11秒! – robintw

+0

糟糕!是的,它看起來像生成器表達式應該是'new_shape = tuple(dim + 6 for dim in hot_image_view.shape)',而不是'new_shape = tuple(dim + 6 for dim in hot_image.shape)''。那是你改變的嗎?請隨時編輯我的答案以反映工作代碼。 – Jaime

2

有一兩件事你可以做的時間相同的一點是挽救seg_image == i的結果,這樣你就不必計算它的兩倍。你正在計算它的行​​數,你可以添加seg_mask = seg_image == i然後重新使用這個結果(爲了分析的目的,分離出這個結果也是很好的)。

雖然還有其他一些小的事情可以做出一點表現,但根本問題是您使用的是O(M * N)算法,其中M是段的數量,N是你的形象的大小。從代碼中我不明白是否有更快的算法來完成同樣的事情,但這是我嘗試尋找加速的第一個位置。