2017-05-15 41 views
5

假設我有一堆陣列,包括xy,我想檢查它們是否相等。一般來說,我可以使用np.all(x == y)(禁止一些我現在忽略的愚蠢的角落案例)。檢查兩個numpy陣列是否相同

但是,這會評估通常不需要的整個陣列(x == y)。我的陣列非常大,而且我有很多它們,並且兩個陣列相等的概率很小,所以很可能我只需要評估(x == y)的一小部分,然後all函數可能會返回假,所以這不是我的最佳解決方案。

我已經使用內置all功能嘗試,結合itertools.izipall(val1==val2 for val1,val2 in itertools.izip(x, y))

不過,這似乎只是在情況下慢得多兩個陣列相等,即總體而言,它使用的STIL不值得超過np.all。我認爲是因爲內建的all的通用性。而np.all不適用於發電機。

有沒有辦法以更快捷的方式做我想做的事?

我知道這個問題類似於以前提出的問題(例如Comparing two numpy arrays for equality, element-wise),但他們並未特別提及提前終止的情況。

+0

你看這個功能:https://docs.scipy.org/doc/numpy-1.12.0/reference/generated/numpy.array_equal.html –

+0

@Thomas :該函數在內部調用'np.all',所以它是沒用的。 (我確實希望有一個專門用於此目的的功能來做短路,但唉,它不會。) – acdr

+1

嗯,這是一個恥辱。我猜,一個numpy內部的函數是你唯一的機會,因爲在numpy之外的任何循環幾乎肯定會變慢。你有沒有考慮直接聯繫開發者? –

回答

4

在此之前在numpy的實現本身就可以編寫自己的功能和JIT編譯它:

import numpy as np 
import numba as nb 


@nb.jit(nopython=True) 
def arrays_equal(a, b): 
    if a.shape != b.shape: 
     return False 
    for ai, bi in zip(a.flat, b.flat): 
     if ai != bi: 
      return False 
    return True 


a = np.random.rand(10, 20, 30) 
b = np.random.rand(10, 20, 30) 


%timeit np.all(a==b) # 100000 loops, best of 3: 9.82 µs per loop 
%timeit arrays_equal(a, a) # 100000 loops, best of 3: 9.89 µs per loop 
%timeit arrays_equal(a, b) # 100000 loops, best of 3: 691 ns per loop 

最壞情況下的性能(數組相等),相當於np.all以及在提早停止編譯函數的情況下,可能會大大超過np.all

+0

我喜歡它,但是對於我的測試數組,如果它們相等,它仍然比'np.all(arr1 == arr2)'長1.6倍。 (作爲參考,'arr1 = np.ones((1000000,),dtype = bool)','arr2 = np.ones((1000000,),dtype = bool)','arr2 [100000] = False')。 (請確保將timeit上的數字降低到1000.) – acdr

+0

@acdr當我使用數組時,'np.all'需要1.8 ms,'arrays_equal'需要183μs。如果我比較'arr1'對自己,大約需要1.8毫秒。也許這種差異是由我們的系統造成的?我有Python 3.5.2,numpy 1.12.1和numba 0.27.0。 – kazemakase

+0

可能。一般來說,我運行的是比以前更老的東西:Python 2.7.10.2,numpy 1.9.1,numba 0.20.0 – acdr

1

numpy page on github上顯然正在討論在陣列比較中添加短路邏輯,因此可能會在未來版本的numpy中提供。

0

您可以迭代數組中的所有元素並檢查它們是否相等。 如果數組最有可能不相等,它將比.all函數返回更快。 事情是這樣的:有numba

import numpy as np 

a = np.array([1, 2, 3]) 
b = np.array([1, 3, 4]) 

areEqual = True 

for x in range(0, a.size-1): 
     if a[x] != b[x]: 
       areEqual = False 
       break 
     else: 
       print "a[x] is equal to b[x]\n" 

if areEqual: 
     print "The tables are equal\n" 
else: 
     print "The tables are not equal\n" 
+0

這實際上是'all(val1 == val2 for val1,val2 in itertools.izip(x,y))':它通過'x'和'y'循環,返回'val1'和'val2'對,檢查它們是否相同,並將結果傳遞給'all',只要它找到一個不相等的對就會返回'False'。 – acdr

+0

哦,我明白了,我認爲它會經歷陣列的所有元素。 – imoutidi

+0

幸運的是,與'np.all'不同,內建'all'確實可以做斷路。 :) – acdr

0

也許有人理解底層的數據結構可以優化或解釋它是否可靠/安全/良好的做法,但它似乎工作。

np.all(a==b) 
Out[]: True 

memoryview(a.data)==memoryview(b.data) 
Out[]: True 

%timeit np.all(a==b) 
The slowest run took 10.82 times longer than the fastest. This could mean that an intermediate result is being cached. 
100000 loops, best of 3: 6.2 µs per loop 

%timeit memoryview(a.data)==memoryview(b.data) 
The slowest run took 8.55 times longer than the fastest. This could mean that an intermediate result is being cached. 
100000 loops, best of 3: 1.85 µs per loop 

如果我理解這個正確,ndarray.data創建一個指針到數據緩衝器和memoryview創建本地蟒類型,可以是短路移出緩衝器。

我想。

編輯:進一步的測試表明它可能不會像所顯示的那麼大的時間改進。此前a=b=np.eye(5)

a=np.random.randint(0,10,(100,100)) 

b=a.copy() 

%timeit np.all(a==b) 
The slowest run took 6.70 times longer than the fastest. This could mean that an intermediate result is being cached. 
10000 loops, best of 3: 17.7 µs per loop 

%timeit memoryview(a.data)==memoryview(b.data) 
10000 loops, best of 3: 30.1 µs per loop 

np.all(a==b) 
Out[]: True 

memoryview(a.data)==memoryview(b.data) 
Out[]: True 
+0

這是不是僅僅測試兩個數組是否實際上是同一個對象的不同名稱,而不是兩個具有相同值的不同對象? – acdr

+0

據我所知,還不夠。使用上面的'.copy()'進行測試,然後以相同的方式依次處理上面的兩個隨機數組。 –

0

嗯,我知道這就是窮人的答案,但它似乎沒有此沒有簡單的方法。 Numpy Creators應該修復它。我建議:

def compare(a, b): 
    if len(a) > 0 and not np.array_equal(a[0], b[0]): 
     return False 
    if len(a) > 15 and not np.array_equal(a[:15], b[:15]): 
     return False 
    if len(a) > 200 and not np.array_equal(a[:200], b[:200]): 
     return False 
    return np.array_equal(a, b) 

:)

+0

那麼,爲什麼你對已經成功解答了2年的問題提供了一個「不好的答案」? –

+0

因爲沒有人說,它不能用numpy完成,並且問題仍然是開放的我認爲 –

+0

[此答案](https://stackoverflow.com/a/43975611/501011)已被接受並使用numpy –