17

有人知道關於此行爲的文檔嗎?Numpy dot太聰明瞭關於對稱乘法

import numpy as np 
A = np.random.uniform(0,1,(10,5)) 
w = np.ones(5) 
Aw = A*w 
Sym1 = Aw.dot(Aw.T) 
Sym2 = (A*w).dot((A*w).T) 
diff = Sym1 - Sym2 

diff.max()是靠近機器精度非零,例如4.4E-16。

這(從0的差異)通常很好......在有限精度的世界中,我們不應該感到驚訝。

而且,我猜numpy的是聰明約對稱的產品,以節省拖鞋和確保對稱輸出...

但我對付混沌系統,而這個小差異很快變得明顯時調試 。所以我想知道到底發生了什麼。

+1

因爲你的代碼將會給從運行之間變化的輸出,請出示一個樣本輸出並更清楚地陳述什麼是不希望的輸出。 –

+0

您是否嘗試強制使用雙打('np.float64')? –

+0

@TomdeGeus怎麼樣?無論如何,請注意,我並不在意這種差異是否爲零。我只是想讓這種行爲(顯然來自numpy聰明)解釋。 – Patrick

回答

19

此行爲是介紹了NumPy的1.11.0,在pull request #6932變化的結果。用於1.11.0的release notes

此前,gemm BLAS操作用於所有矩陣產品。 現在,如果矩陣乘積在矩陣與其轉置之間,它將使用syrk BLAS操作來提高性能。這 優化已擴展到@,numpy.dot,numpy.inner和 numpy.matmul。

在變更爲PR,人們發現this comment

/* 
* Use syrk if we have a case of a matrix times its transpose. 
* Otherwise, use gemm for all other cases. 
*/ 

所以NumPy的正在爲一個矩陣乘以其轉置的情況下,直接檢查,並要求在不同的底層BLAS功能案件。正如@hpaulj在評論中指出的那樣,對於NumPy來說這樣的檢查是便宜的,因爲轉置的二維數組只是原始數組上的一個視圖,而且形狀和步幅都是反向的,所以只需檢查數組中的幾條元數據即可而不必比較實際的陣列數據)。

這是一個稍微簡單的案例,顯示了差異。請注意,對dot的參數之一使用.copy足以擊敗NumPy的特殊框架。

import numpy as np 
random = np.random.RandomState(12345) 
A = random.uniform(size=(10, 5)) 
Sym1 = A.dot(A.T) 
Sym2 = A.dot(A.T.copy()) 
print(abs(Sym1 - Sym2).max()) 

我想這個特殊外殼的一個優點,超出了加速的明顯的潛在,是你保證(我倒是希望,但在實踐中它會依賴於BLAS實現)在使用syrk時得到完全對稱的結果,而不是僅僅對稱於數值誤差的矩陣。作爲一個(當然不是很好)試驗對於這一點,我想:

import numpy as np 
random = np.random.RandomState(12345) 
A = random.uniform(size=(100, 50)) 
Sym1 = A.dot(A.T) 
Sym2 = A.dot(A.T.copy()) 
print("Sym1 symmetric: ", (Sym1 == Sym1.T).all()) 
print("Sym2 symmetric: ", (Sym2 == Sym2.T).all()) 

我的機器上結果:

Sym1 symmetric: True 
Sym2 symmetric: False 
+3

「轉置」是一個連續翻轉(F連續)並跨越的「視圖」。所以它可以通過比較幾個屬性來識別,而不需要查看任何數據(這會很昂貴)。 – hpaulj

+0

剛剛發現,至少在發行說明中記錄了此更改。答案已更新。 –

1

我懷疑這是關於中間浮點寄存器升級到80位精度。有點證實這個假設是,如果我們使用更少的花車,我們始終得到我們的結果0,ALA

A = np.random.uniform(0,1,(4,2)) 
w = np.ones(2) 
Aw = A*w 
Sym1 = Aw.dot(Aw.T) 
Sym2 = (A*w).dot((A*w).T) 
diff = Sym1 - Sym2 
# diff is all 0's (ymmv) 
+0

這是可能的。但我認爲你的實驗也可以用我的建議來解釋:numpy用不同的方式處理對稱產品。我們怎麼能確定地知道? – Patrick

+3

我們可以確定這不是原因,因爲python中間值永遠不會存儲80位精度,因爲它們存儲在內存中而不是寄存器。這裏重要的是'A * w'被計算兩次,而不是計算它作爲更大表達式 – Eric

+0

@Patrick的一部分。最可靠的方法是閱讀源代碼。 –