meshgrid/where/indexing解決方案已經非常快。我做了大約65%的速度。這不是太多,但我無論如何一步一步解釋:
對於我來說,解決這個問題最簡單的方法是將網格中的所有3D向量作爲一個大的2D 3 x M
數組中的列。 meshgrid
是創建所有組合的正確工具(請注意,3D網格網格需要numpy version> = 1.7),並且vstack
+ reshape
將數據轉換爲所需的形式。例如:
>>> np.vstack(np.meshgrid(*[np.arange(0, 2)]*3)).reshape(3,-1)
array([[0, 0, 1, 1, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 0, 1]])
每列是一個3D矢量。這八個矢量中的每一個表示一個1x1x1
立方體的一個角(在所有維中具有步長1和長度1的3D網格)。
我們稱這個數組爲vectors
(它包含表示網格中所有點的所有3D向量)。然後,準備一個bool掩模用於選擇那些滿足您MOD2標準載體:
mod2bool = np.sum(vectors, axis=0) % 2 == 0
np.sum(vectors, axis=0)
創建1 x M
陣列包含用於每個列向量的元素之和。因此,mod2bool
是一個1 x M
數組,每個列向量具有一個bool值。現在用這個布爾面膜:
vectorsubset = vectors[:,mod2bool]
這將選擇所有行(:),並使用布爾索引用於過濾列,都是在numpy的快速操作。計算剩餘向量的長度,使用本機numpy的方法:
lengths = np.sqrt(np.sum(vectorsubset**2, axis=0))
這是相當快的 - 但是,scipy.stats.ss
和bottleneck.ss
可以速度甚至比這個執行平方和運算。由零
with np.errstate(divide='ignore'):
p = (r/lengths)**n
這涉及有限數量的劃分,從而導致Inf
S中的輸出數組中:
使用你的指令變換長度。這完全沒問題。我們使用numpy的errstate
上下文管理器來確保這些零分區不會拋出異常或運行時警告。
現在總結有限元(忽略的INF),並返回的總和:
return np.sum(p[np.isfinite(p)])
我在下面兩次實現此方法。一旦完全像剛剛解釋的那樣,並且一旦涉及瓶頸的ss
和nansum
函數。我還添加了用於比較的方法,以及跳過np.where((x*x+y*y+z*z)!=0)
索引的方法的修改版本,而是創建了Inf
s,最後總結了isfinite
的方法。
import sys
import numpy as np
import bottleneck as bn
N = 100
n = 12
r = np.sqrt(2)
x,y,z = np.meshgrid(*[np.arange(-N, N+1)]*3)
gridvectors = np.vstack((x,y,z)).reshape(3, -1)
def measure_time(func):
import time
def modified_func(*args, **kwargs):
t0 = time.time()
result = func(*args, **kwargs)
duration = time.time() - t0
print("%s duration: %.3f s" % (func.__name__, duration))
return result
return modified_func
@measure_time
def method_columnvecs(vectors):
mod2bool = np.sum(vectors, axis=0) % 2 == 0
vectorsubset = vectors[:,mod2bool]
lengths = np.sqrt(np.sum(vectorsubset**2, axis=0))
with np.errstate(divide='ignore'):
p = (r/lengths)**n
return np.sum(p[np.isfinite(p)])
@measure_time
def method_columnvecs_opt(vectors):
# On my system, bn.nansum is even slightly faster than np.sum.
mod2bool = bn.nansum(vectors, axis=0) % 2 == 0
# Use ss from bottleneck or scipy.stats (axis=0 is default).
lengths = np.sqrt(bn.ss(vectors[:,mod2bool]))
with np.errstate(divide='ignore'):
p = (r/lengths)**n
return bn.nansum(p[np.isfinite(p)])
@measure_time
def method_original(x,y,z):
ind = np.where((x+y+z)%2==0)
x = x[ind]
y = y[ind]
z = z[ind]
ind = np.where((x*x+y*y+z*z)!=0)
x = x[ind]
y = y[ind]
z = z[ind]
p=np.sqrt(x*x+y*y+z*z)/r
return np.sum((1/p)**n)
@measure_time
def method_original_finitesum(x,y,z):
ind = np.where((x+y+z)%2==0)
x = x[ind]
y = y[ind]
z = z[ind]
lengths = np.sqrt(x*x+y*y+z*z)
with np.errstate(divide='ignore'):
p = (r/lengths)**n
return np.sum(p[np.isfinite(p)])
print method_columnvecs(gridvectors)
print method_columnvecs_opt(gridvectors)
print method_original(x,y,z)
print method_original_finitesum(x,y,z)
這是輸出:
$ python test.py
method_columnvecs duration: 1.295 s
12.1318801965
method_columnvecs_opt duration: 1.162 s
12.1318801965
method_original duration: 1.936 s
12.1318801965
method_original_finitesum duration: 1.714 s
12.1318801965
的所有方法產生相同的結果。在執行isfinite
樣式總和時,您的方法會變得更快一些。我的方法更快,但我會說這是一個學術性質的練習,而不是一個重要的改進:-)
我還有一個問題:你是說,對於N = 3,計算應該產生一個12甚至你的也不會這樣做。以上所有方法對於N = 3產生12.1317530867。這是預期的嗎?
「我有困難」沒有幫助 - 您是否收到錯誤(提供追溯)?意想不到的產出(提供投入,預期產出,實際產出)? – jonrsharpe
@jonrsharpe:你看過這個問題嗎?代碼是正確的,但速度很慢,所以OP想把一些工作推到快速的numpy庫中,而不是在慢Python中執行循環。 –
是的,代碼是正確的,我只是想要它的矢量化。我想學習如何使它儘可能高效 – NightHallow