2014-02-17 147 views
3

我需要幫助矢量化此代碼。現在,N = 100,運行需要一分鐘左右的時間。我想加快速度。我已經做了這樣的雙循環,但從來沒有一個3D循環,我有困難。3D距離矢量化

import numpy as np 
N = 100 
n = 12 
r = np.sqrt(2) 

x = np.arange(-N,N+1) 
y = np.arange(-N,N+1) 
z = np.arange(-N,N+1) 

C = 0 

for i in x: 
    for j in y: 
     for k in z: 
      if (i+j+k)%2==0 and (i*i+j*j+k*k!=0): 
       p = np.sqrt(i*i+j*j+k*k) 
       p = p/r 
       q = (1/p)**n 
       C += q 

print '\n' 
print C 
+0

「我有困難」沒有幫助 - 您是否收到錯誤(提供追溯)?意想不到的產出(提供投入,預期產出,實際產出)? – jonrsharpe

+1

@jonrsharpe:你看過這個問題嗎?代碼是正確的,但速度很慢,所以OP想把一些工作推到快速的numpy庫中,而不是在慢Python中執行循環。 –

+0

是的,代碼是正確的,我只是想要它的矢量化。我想學習如何使它儘可能高效 – NightHallow

回答

2

meshgrid/where/indexing解決方案已經非常快。我做了大約65%的速度。這不是太多,但我無論如何一步一步解釋:

對於我來說,解決這個問題最簡單的方法是將網格中的所有3D向量作爲一個大的2D 3 x M數組中的列。 meshgrid是創建所有組合的正確工具(請注意,3D網格網格需要numpy version> = 1.7),並且vstack + reshape將數據轉換爲所需的形式。例如:

>>> np.vstack(np.meshgrid(*[np.arange(0, 2)]*3)).reshape(3,-1) 
array([[0, 0, 1, 1, 0, 0, 1, 1], 
     [0, 0, 0, 0, 1, 1, 1, 1], 
     [0, 1, 0, 1, 0, 1, 0, 1]]) 

每列是一個3D矢量。這八個矢量中的每一個表示一個1x1x1立方體的一個角(在所有維中具有步長1和長度1的3D網格)。

我們稱這個數組爲vectors(它包含表示網格中所有點的所有3D向量)。然後,準備一個bool掩模用於選擇那些滿足您MOD2標準載體:

mod2bool = np.sum(vectors, axis=0) % 2 == 0 

np.sum(vectors, axis=0)創建1 x M陣列包含用於每個列向量的元素之和。因此,mod2bool是一個1 x M數組,每個列向量具有一個bool值。現在用這個布爾面膜:

vectorsubset = vectors[:,mod2bool] 

這將選擇所有行(:),並使用布爾索引用於過濾列,都是在numpy的快速操作。計算剩餘向量的長度,使用本機numpy的方法:

lengths = np.sqrt(np.sum(vectorsubset**2, axis=0)) 

這是相當快的 - 但是,scipy.stats.ssbottleneck.ss可以速度甚至比這個執行平方和運算。由零

with np.errstate(divide='ignore'): 
     p = (r/lengths)**n 

這涉及有限數量的劃分,從而導致Inf S中的輸出數組中:

使用你的指令變換長度。這完全沒問題。我們使用numpy的errstate上下文管理器來確保這些零分區不會拋出異常或運行時警告。

現在總結有限元(忽略的INF),並返回的總和:

return np.sum(p[np.isfinite(p)]) 

我在下面兩次實現此方法。一旦完全像剛剛解釋的那樣,並且一旦涉及瓶頸的ssnansum函數。我還添加了用於比較的方法,以及跳過np.where((x*x+y*y+z*z)!=0)索引的方法的修改版本,而是創建了Inf s,最後總結了isfinite的方法。

import sys 
import numpy as np 
import bottleneck as bn 

N = 100 
n = 12 
r = np.sqrt(2) 


x,y,z = np.meshgrid(*[np.arange(-N, N+1)]*3) 
gridvectors = np.vstack((x,y,z)).reshape(3, -1) 


def measure_time(func): 
    import time 
    def modified_func(*args, **kwargs): 
     t0 = time.time() 
     result = func(*args, **kwargs) 
     duration = time.time() - t0 
     print("%s duration: %.3f s" % (func.__name__, duration)) 
     return result 
    return modified_func 


@measure_time 
def method_columnvecs(vectors): 
    mod2bool = np.sum(vectors, axis=0) % 2 == 0 
    vectorsubset = vectors[:,mod2bool] 
    lengths = np.sqrt(np.sum(vectorsubset**2, axis=0)) 
    with np.errstate(divide='ignore'): 
     p = (r/lengths)**n 
    return np.sum(p[np.isfinite(p)]) 


@measure_time 
def method_columnvecs_opt(vectors): 
    # On my system, bn.nansum is even slightly faster than np.sum. 
    mod2bool = bn.nansum(vectors, axis=0) % 2 == 0 
    # Use ss from bottleneck or scipy.stats (axis=0 is default). 
    lengths = np.sqrt(bn.ss(vectors[:,mod2bool])) 
    with np.errstate(divide='ignore'): 
     p = (r/lengths)**n 
    return bn.nansum(p[np.isfinite(p)]) 


@measure_time 
def method_original(x,y,z): 
    ind = np.where((x+y+z)%2==0) 
    x = x[ind] 
    y = y[ind] 
    z = z[ind] 
    ind = np.where((x*x+y*y+z*z)!=0) 
    x = x[ind] 
    y = y[ind] 
    z = z[ind] 
    p=np.sqrt(x*x+y*y+z*z)/r 
    return np.sum((1/p)**n) 


@measure_time 
def method_original_finitesum(x,y,z): 
    ind = np.where((x+y+z)%2==0) 
    x = x[ind] 
    y = y[ind] 
    z = z[ind] 
    lengths = np.sqrt(x*x+y*y+z*z) 
    with np.errstate(divide='ignore'): 
     p = (r/lengths)**n 
    return np.sum(p[np.isfinite(p)]) 


print method_columnvecs(gridvectors) 
print method_columnvecs_opt(gridvectors) 
print method_original(x,y,z) 
print method_original_finitesum(x,y,z) 

這是輸出:

$ python test.py 
method_columnvecs duration: 1.295 s 
12.1318801965 
method_columnvecs_opt duration: 1.162 s 
12.1318801965 
method_original duration: 1.936 s 
12.1318801965 
method_original_finitesum duration: 1.714 s 
12.1318801965 

的所有方法產生相同的結果。在執行isfinite樣式總和時,您的方法會變得更快一些。我的方法更快,但我會說這是一個學術性質的練習,而不是一個重要的改進:-)

我還有一個問題:你是說,對於N = 3,計算應該產生一個12甚至你的也不會這樣做。以上所有方法對於N = 3產生12.1317530867。這是預期的嗎?

+0

我認爲我的意思是N = 1應該給12,我只用於測試用例,看看我的東西是否正確。非常感謝您的意見和幫助。我學到了很多東西,我甚至都沒有想過。謝謝! – NightHallow

2

感謝@Bill,我能夠得到這個工作。現在非常快。也許可以做得更好,特別是用兩個面具來擺脫我最初爲循環所用的兩個條件。

from __future__ import division 
    import numpy as np 

    N = 100 
    n = 12 
    r = np.sqrt(2) 

    x, y, z = np.meshgrid(*[np.arange(-N, N+1)]*3) 

    ind = np.where((x+y+z)%2==0) 
    x = x[ind] 
    y = y[ind] 
    z = z[ind] 
    ind = np.where((x*x+y*y+z*z)!=0) 
    x = x[ind] 
    y = y[ind] 
    z = z[ind] 

    p=np.sqrt(x*x+y*y+z*z)/r 

    ans = (1/p)**n 
    ans = np.sum(ans) 
    print 'ans' 
    print ans 
+0

我對這些條件'(x + y + z)%2 == 0'和' (x * x + y * y + z * z)!= 0'來自。任何物理意義?後者只是一個繞零工作的過濾器? numpy可以處理這個問題。 –

+0

Just FYI,'距離= distance.pdist(向量,'sqeuclidean')'可能顯着加快你的'np.sqrt(x * x + y * y + z * z)',所以你可能想在過濾之前這樣做。 –

+0

如果我把它放在過濾之前,那麼我想怎麼過濾呢?而且,這些條件來自我正在觀察的原子晶格的物理意義。它是固體物理學,這些是我用於格點的條件,格點設置爲只有原子,如果座標偶數相加,原點上沒有,因爲它是參考原子。 – NightHallow