2009-11-23 120 views
0

我試圖使用花式索引而不是循環來加速Numpy中的函數。據我所知,我已經正確實施了花哨的索引版本。問題是兩個函數(循環和花式索引)不會返回相同的結果。我不知道爲什麼。值得指出的是,如果使用較小的數組(如20 x 20 x 20),函數的返回值相同。爲什麼這兩個數學函數不會返回相同的結果?

下面我已經包含了重現錯誤所需的一切。如果函數返回相同的結果,那麼行find_maxdiff(data) - find_maxdiff_fancy(data)應返回一個滿零的數組。

from numpy import * 

def rms(data, axis=0): 
    return sqrt(mean(data ** 2, axis)) 

def find_maxdiff(data): 
    samples, channels, epochs = shape(data) 
    window_size = 50 
    maxdiff = zeros(epochs) 
    for epoch in xrange(epochs): 
     signal = rms(data[:, :, epoch], axis=1) 
     for t in xrange(window_size, alen(signal) - window_size): 
      amp_a = mean(signal[t-window_size:t], axis=0) 
      amp_b = mean(signal[t:t+window_size], axis=0) 
      the_diff = abs(amp_b - amp_a) 
      if the_diff > maxdiff[epoch]: 
       maxdiff[epoch] = the_diff 

    return maxdiff 

def find_maxdiff_fancy(data): 
    samples, channels, epochs = shape(data) 
    window_size = 50 
    maxdiff = zeros(epochs) 
    signal = rms(data, axis=1) 
    for t in xrange(window_size, alen(signal) - window_size): 
     amp_a = mean(signal[t-window_size:t], axis=0) 
     amp_b = mean(signal[t:t+window_size], axis=0) 
     the_diff = abs(amp_b - amp_a) 
     maxdiff[the_diff > maxdiff] = the_diff 

    return maxdiff 

data = random.random((600, 20, 100)) 
find_maxdiff(data) - find_maxdiff_fancy(data) 

data = random.random((20, 20, 20)) 
find_maxdiff(data) - find_maxdiff_fancy(data) 
+5

兩者之間有什麼不同?這不是典型的浮點精度問題,它吸引了很多人? – spender 2009-11-23 09:50:32

+0

20x20x20和600x20x100之間的值是否會出現問題?事情是否逐步出現錯誤,越來越多,或一次全部出錯? – AakashM 2009-11-23 10:00:20

+0

差異的程度相當大,只是浮點錯誤。 – pealco 2009-11-23 16:08:32

回答

3

問題是這樣的線:

maxdiff[the_diff > maxdiff] = the_diff 

左側只選擇maxdiff的一些元件,但是右側包含the_diff的所有元素。這應該工作,而不是:

replaceElements = the_diff > maxdiff 
maxdiff[replaceElements] = the_diff[replaceElements] 

或者乾脆:

maxdiff = maximum(maxdiff, the_diff) 

至於爲什麼20x20x20大小似乎工作:這是因爲你的窗口尺寸太大,所以沒有得到執行。

+0

感謝您幫助我更好地理解特別的分配是如何工作的。此外,我應該已經捕捉到爲什麼小陣列工作的愚蠢理由:)再次感謝。 – pealco 2009-11-23 15:58:03

0

首先,在現在看中你的信號是2D的,如果我理解正確的 - 所以我認爲這將是更清晰的索引它明確(例如amp_a =平均(信號[T-WINDOW_SIZE:T ,:],軸= 0),同樣與阿倫(信號) - 在這兩種情況下,這應該只是樣品,所以我認爲這將是更清楚使用

這是錯誤的,只要你實際上在做的t循環的東西 - 當samples < window_lenght就像在20x20x20的例子中一樣,這個循環從來沒有被執行過,只要這個循環被多次執行(即samples > 2 *window_length+1),那麼錯誤就會出現。不知道爲什麼 - 它們的確看起來和我相當。

相關問題