2010-08-19 24 views
10

幾年前,有人posted活動狀態食譜作比較,三個python/NumPy函數;其中每個接受相同的參數並返回相同的結果,距離矩陣爲什麼在這裏循環跳動索引?

其中兩個是從公開資料中獲得的;他們都是 - 或者他們似乎是我慣用的numpy代碼。創建距離矩陣所需的重複計算由numpy的優雅索引語法驅動。這裏是其中的一個:

from numpy.matlib import repmat, repeat 

def calcDistanceMatrixFastEuclidean(points): 
    numPoints = len(points) 
    distMat = sqrt(sum((repmat(points, numPoints, 1) - 
      repeat(points, numPoints, axis=0))**2, axis=1)) 
    return distMat.reshape((numPoints,numPoints)) 

使用一個循環(這顯然是一個很大的循環考慮到的只是1000的2D點的距離矩陣,有一萬個條目)第三創建距離矩陣。乍一看,這個函數在我看來像我在學習NumPy時編寫的代碼,我會先編寫NumPy代碼,然後逐行翻譯它。

活動狀態帖子發佈幾個月後,比較三者的性能測試結果在NumPy郵件列表上發佈並在thread中討論。

與事實上的循環功能顯著跑贏另外兩個:在線程

from numpy import mat, zeros, newaxis 

def calcDistanceMatrixFastEuclidean2(nDimPoints): 
    nDimPoints = array(nDimPoints) 
    n,m = nDimPoints.shape 
    delta = zeros((n,n),'d') 
    for d in xrange(m): 
    data = nDimPoints[:,d] 
    delta += (data - data[:,newaxis])**2 
    return sqrt(delta) 

一位與會者(凱爾·Mierle)提供一個理由,這可能是真實的:

我懷疑這會更快的原因是 它具有更好的地方性,完全完成一個相對較小的工作集上的計算,然後再轉到下一個工作集之一。一行 必須重複將可能較大的MxN陣列拉入處理器。

通過這張海報自己的帳戶,他的評論只是一個懷疑,似乎並沒有進一步討論。

有關如何解釋這些結果的其他想法?

特別是,有沒有一個有用的規則 - 關於什麼時候循環和何時索引 - 可以從這個例子中提取作爲編寫numpy代碼的指導?

對於那些不熟悉NumPy的人,或者沒有看過代碼的人,這種比較不是基於邊緣案例 - 如果是的話,這對我來說肯定不會那麼有趣。相反,這種比較涉及在矩陣計算中執行共同任務的功能(即,創建給定兩個前件的結果數組)。而且,每個函數都是由最常見的numpy內建插件組成的。

回答

11

TL; DR上面的第二個代碼僅在點的維數上循環(對於3D點,通過for循環的次數爲3次),因此循環並不多。上面第二個代碼中真正的加速是,它更好地利用了Numpy的力量,以避免在找到點之間的差異時創建一些額外的矩陣。這減少了使用的內存和計算量。

更長解釋 我認爲calcDistanceMatrixFastEuclidean2函數正在欺騙你的循環或許。它僅循環點的維數。對於1D點,循環只執行一次,對於2D,兩次,對於3D,則執行三次。這實際上並沒有太多循環。

讓我們分析一下代碼,看看爲什麼這個代碼比另一個更快。 calcDistanceMatrixFastEuclidean我會打電話fast1calcDistanceMatrixFastEuclidean2fast2

fast1是基於Matlab的做事方式,如repmap函數所證明的那樣。在這種情況下,repmap函數會創建一個數組,它只是原來的數據一遍又一遍地重複。但是,如果您查看該函數的代碼,則效率非常低。它使用許多Numpy功能(3 reshape s和2 repeat s)來執行此操作。 repeat函數也用於創建一個包含原始數據的數組,每個數據項重複多次。如果我們的輸入數據是[1,2,3],那麼我們從[1,1,1,2,2,2,3,3,3]減去[1,2,3,1,2,3,1,2,3]。 Numpy必須在運行Numpy的C代碼之間創建大量額外的矩陣,而這些代碼本可以避免。

fast2使用更多的Numpy的繁重工作,而不會在Numpy調用之間創建儘可能多的矩陣。 fast2循環通過點的每個維度,進行減法並保持每個維度之間的平方差的總計。只有最後纔是平方根。到目前爲止,這可能聽起來不如fast1那樣有效,但fast2通過使用Numpy的索引避免了做repmat的東西。爲簡單起見,我們來看一維情況。 fast2製作數據的一維數組,並從數據的2D(N×1)數組中減去它。這將創建每個點與所有其他點之間的差異矩陣,而不必使用repmatrepeat,從而繞過創建大量額外數組。這是真正的速度差異在我看來。 fast1在矩陣之間創建了許多額外的內容(並且它們的計算開銷很大)以找到點之間的差異,而fast2更好地利用了Numpy的力量來避免這些差異。

順便說一句,這裏是一個有點快的fast2版本:

def calcDistanceMatrixFastEuclidean3(nDimPoints): 
    nDimPoints = array(nDimPoints) 
    n,m = nDimPoints.shape 
    data = nDimPoints[:,0] 
    delta = (data - data[:,newaxis])**2 
    for d in xrange(1,m): 
    data = nDimPoints[:,d] 
    delta += (data - data[:,newaxis])**2 
    return sqrt(delta) 

不同的是,我們不再產生增量的零矩陣。

+0

非常有幫助,謝謝。從我+1。 – doug 2010-08-19 07:31:20

1

dis的樂趣:

dis.dis(calcDistanceMatrixFastEuclidean)

2   0 LOAD_GLOBAL    0 (len) 
       3 LOAD_FAST    0 (points) 
       6 CALL_FUNCTION   1 
       9 STORE_FAST    1 (numPoints) 

    3   12 LOAD_GLOBAL    1 (sqrt) 
      15 LOAD_GLOBAL    2 (sum) 
      18 LOAD_GLOBAL    3 (repmat) 
      21 LOAD_FAST    0 (points) 
      24 LOAD_FAST    1 (numPoints) 
      27 LOAD_CONST    1 (1) 
      30 CALL_FUNCTION   3 

    4   33 LOAD_GLOBAL    4 (repeat) 
      36 LOAD_FAST    0 (points) 
      39 LOAD_FAST    1 (numPoints) 
      42 LOAD_CONST    2 ('axis') 
      45 LOAD_CONST    3 (0) 
      48 CALL_FUNCTION   258 
      51 BINARY_SUBTRACT 
      52 LOAD_CONST    4 (2) 
      55 BINARY_POWER 
      56 LOAD_CONST    2 ('axis') 
      59 LOAD_CONST    1 (1) 
      62 CALL_FUNCTION   257 
      65 CALL_FUNCTION   1 
      68 STORE_FAST    2 (distMat) 

    5   71 LOAD_FAST    2 (distMat) 
      74 LOAD_ATTR    5 (reshape) 
      77 LOAD_FAST    1 (numPoints) 
      80 LOAD_FAST    1 (numPoints) 
      83 BUILD_TUPLE    2 
      86 CALL_FUNCTION   1 
      89 RETURN_VALUE 

dis.dis(calcDistanceMatrixFastEuclidean2)

2   0 LOAD_GLOBAL    0 (array) 
       3 LOAD_FAST    0 (nDimPoints) 
       6 CALL_FUNCTION   1 
       9 STORE_FAST    0 (nDimPoints) 

    3   12 LOAD_FAST    0 (nDimPoints) 
      15 LOAD_ATTR    1 (shape) 
      18 UNPACK_SEQUENCE   2 
      21 STORE_FAST    1 (n) 
      24 STORE_FAST    2 (m) 

    4   27 LOAD_GLOBAL    2 (zeros) 
      30 LOAD_FAST    1 (n) 
      33 LOAD_FAST    1 (n) 
      36 BUILD_TUPLE    2 
      39 LOAD_CONST    1 ('d') 
      42 CALL_FUNCTION   2 
      45 STORE_FAST    3 (delta) 

    5   48 SETUP_LOOP    76 (to 127) 
      51 LOAD_GLOBAL    3 (xrange) 
      54 LOAD_FAST    2 (m) 
      57 CALL_FUNCTION   1 
      60 GET_ITER 
     >> 61 FOR_ITER    62 (to 126) 
      64 STORE_FAST    4 (d) 

    6   67 LOAD_FAST    0 (nDimPoints) 
      70 LOAD_CONST    0 (None) 
      73 LOAD_CONST    0 (None) 
      76 BUILD_SLICE    2 
      79 LOAD_FAST    4 (d) 
      82 BUILD_TUPLE    2 
      85 BINARY_SUBSCR 
      86 STORE_FAST    5 (data) 

    7   89 LOAD_FAST    3 (delta) 
      92 LOAD_FAST    5 (data) 
      95 LOAD_FAST    5 (data) 
      98 LOAD_CONST    0 (None) 
      101 LOAD_CONST    0 (None) 
      104 BUILD_SLICE    2 
      107 LOAD_GLOBAL    4 (newaxis) 
      110 BUILD_TUPLE    2 
      113 BINARY_SUBSCR 
      114 BINARY_SUBTRACT 
      115 LOAD_CONST    2 (2) 
      118 BINARY_POWER 
      119 INPLACE_ADD 
      120 STORE_FAST    3 (delta) 
      123 JUMP_ABSOLUTE   61 
     >> 126 POP_BLOCK 

    8  >> 127 LOAD_GLOBAL    5 (sqrt) 
      130 LOAD_FAST    3 (delta) 
      133 CALL_FUNCTION   1 
      136 RETURN_VALUE 

我不是dis的專家,但好像你不得不看更多在f第一次打電話告訴他們爲什麼需要一段時間。還有一個使用Python的性能分析工具,cProfile

+1

如果您使用[cProfile](http://docs.python.org/library/profile.html#instant-user-s-manual),我建議使用[RunSnakeRun](http:// www。 vrplumber.com/programming/runsnakerun/)查看結果。 – detly 2010-08-19 04:25:17

+0

我注意到,Python優化的技巧似乎通常是讓Python解釋器儘可能少地執行Python指令。 – Omnifarious 2011-02-26 03:32:17

相關問題