2013-01-23 45 views
13

編輯我已經保持了我下面面臨的更復雜的問題,但我的問題與np.take可以更好地總結如下。假設您有一個形狀爲(planes, rows)的陣列img,以及形狀爲(planes, 256)的另一個陣列lut,並且您想要使用它們來創建形狀爲(planes, rows)的新陣列out,其中out[p,j] = lut[p, img[p, j]]。這可以用花哨的索引來實現如下:使用numpy.take更快的花式索引

In [4]: %timeit lut[np.arange(planes).reshape(-1, 1), img] 
1000 loops, best of 3: 471 us per loop 

但是,如果不是看中了索引,您使用採取並在planes事情蟒蛇循環可以極大地加快了:

In [6]: %timeit for _ in (lut[j].take(img[j]) for j in xrange(planes)) : pass 
10000 loops, best of 3: 59 us per loop 

lutimg是以某種方式重新排列,從而具有整體操作發生沒有蟒迴路,但使用numpy.take(或另一種方法),而不是常規的花式索引,以保持速度優勢?


原來的問題 我有一組,我想在圖像上使用查找表(LUT)的。保存LUT的陣列形狀爲(planes, 256, n),圖像形狀爲(planes, rows, cols)。兩者都是dtype = 'uint8',與LUT的256軸相匹配。這個想法是通過LUT的p平面中的每個n LUT運行圖像的第012平面。

如果我lutimg如下:

planes, rows, cols, n = 3, 4000, 4000, 4 
lut = np.random.randint(-2**31, 2**31 - 1, 
         size=(planes * 256 * n // 4,)).view('uint8') 
lut = lut.reshape(planes, 256, n) 
img = np.random.randint(-2**31, 2**31 - 1, 
        size=(planes * rows * cols // 4,)).view('uint8') 
img = img.reshape(planes, rows, cols) 

我可以用花哨的索引這樣

out = lut[np.arange(planes).reshape(-1, 1, 1), img] 

這給了我形(planes, rows, cols, n)的陣列,其中out[i, :, :, j]後實現我是什麼持有i-第img號飛機貫穿i LUT的第LUT平面......

一切都很好,除了這一點:

In [2]: %timeit lut[np.arange(planes).reshape(-1, 1, 1), img] 
1 loops, best of 3: 5.65 s per loop 

這是完全不能接受的,尤其是因爲我擁有所有使用np.take以下不那麼好看的替代品的速度遠遠超過運行:

  1. 單個平面上的單個LUT運行速度快x70:

    In [2]: %timeit np.take(lut[0, :, 0], img[0]) 
    10 loops, best of 3: 78.5 ms per loop 
    
  2. 通過所有的期望的組合中運行的蟒循環完成幾乎5233更快:

    In [2]: %timeit for _ in (np.take(lut[j, :, k], img[j]) for j in xrange(planes) for k in xrange(n)) : pass 
    1 loops, best of 3: 947 ms per loop 
    
  3. 即使運行平面的所有組合中的LUT和圖像,然後丟棄該planes**2 - planes不想要的低於花式索引更快:

    In [2]: %timeit np.take(lut, img, axis=1)[np.arange(planes), np.arange(planes)] 
    1 loops, best of 3: 3.79 s per loop 
    
  4. 而且我已經能夠拿出最快的組合,擁有飛機蟒蛇循環迭代,並完成X13更快:

    In [2]: %timeit for _ in (np.take(lut[j], img[j], axis=0) for j in xrange(planes)) : pass 
    1 loops, best of 3: 434 ms per loop 
    

當然問題是,如果沒有辦法做到這一點np.take沒有任何python循環?理想的情況是任何整形或調整大小,需要對LUT,而不是圖像應該發生,但我很開放,無論你的人能拿出...

+0

什麼是'在你的代碼片段bkpt' - 沒有必要解釋,我只是想提醒你的情況下,它的一個錯字 - 我猜它應該是'lut'? –

+0

......不應該把這整行讀成'lut = lut.reshape(planes,256,4)',所以'4'在最後一個dim? –

+0

@TheodrosZelleke謝謝你抓住那些!我的'lut'實際上是一個**斷點表**,所以在我的代碼中它被稱爲'bkpt',並且當我翻譯它時,我錯過了這個問題。 – Jaime

回答

5

我不得不說,我真的很喜歡你的問題的拳頭。如果不重新安排LUTIMG以下解決方案工作:

%timeit a=np.take(lut, img, axis=1) 
# 1 loops, best of 3: 1.93s per loop 

但是從結果你要查詢的對角線:一[0,0],[1,1],一個[2,2]。得到你想要的東西。我一直試圖找到一種方法,只爲對角元素做索引,但仍然沒有管理。

這裏有一些方法來重新排列你的LUTIMG:爲第三 下面的作品如果IMG索引是從0到255,爲第1面,256-511爲第二平面,512-767飛機,但會阻止您使用'uint8',它可以是一個大問題...:

lut2 = lut.reshape(-1,4) 
%timeit np.take(lut2,img,axis=0) 
# 1 loops, best of 3: 716 ms per loop 
# or 
%timeit np.take(lut2, img.flatten(), axis=0).reshape(3,4000,4000,4) 
# 1 loops, best of 3: 709 ms per loop 
在我的機器

您的解決方案仍然是最好的選擇,而且非常充足,因爲你只需要對角的評價,即基準面1,基準面1,平面2,平面2和plane3-plane3:

%timeit for _ in (np.take(lut[j], img[j], axis=0) for j in xrange(planes)) : pass 
# 1 loops, best of 3: 677 ms per loop 

我希望這可以給你一個更好的解決方案的一些見解。這將是很好,尋找更多的選擇與flatten(),和類似的方法作爲np.apply_over_axes()np.apply_along_axis(),似乎是有希望的。

我用下面這段代碼來生成數據:

import numpy as np 
num = 4000 
planes, rows, cols, n = 3, num, num, 4 
lut = np.random.randint(-2**31, 2**31-1,size=(planes*256*n//4,)).view('uint8') 
lut = lut.reshape(planes, 256, n) 
img = np.random.randint(-2**31, 2**31-1,size=(planes*rows*cols//4,)).view('uint8') 
img = img.reshape(planes, rows, cols)