2017-06-06 210 views
1

我有兩個數組A(4000,4000),其中只有對角線填充數據,而B(4000,5)填充數據。有沒有辦法將這些數組乘以(點)比numpy.dot(a,b)函數更快的方法?Python numpy矩陣乘以一個對角矩陣

到目前爲止,我發現(A * B.T).T應該更快(其中A是一維(4000,),填充了對角線元素),但事實證明,它的速度大約是速度的兩倍。

有沒有更快的方法來計算B.dot(A)在A是diagnal數組的情況下?

+0

那些是陣列或矩陣?另外,你確定(B * A.T).T? – Divakar

+0

Numpy矩陣,我做了測試(B * AT).T,但只用於一個小矩陣,所以我會嘗試一個大矩陣併發布結果 – LMB

+0

我的意思是A(4000,5),B( 4000,4000),那麼'B * AT'就會有錯誤的錯誤。 – Divakar

回答

1

您可以簡單地提取對角線元素,然後執行廣播元素乘法。

因此,對於B*A替換將是 -

np.multiply(np.diag(B)[:,None], A) 

A.T*B -

np.multiply(A.T,np.diag(B)) 

運行試驗 -

In [273]: # Setup 
    ...: M,N = 4000,5 
    ...: A = np.random.randint(0,9,(M,N)).astype(float) 
    ...: B = np.zeros((M,M),dtype=float) 
    ...: np.fill_diagonal(B, np.random.randint(11,99,(M))) 
    ...: A = np.matrix(A) 
    ...: B = np.matrix(B) 
    ...: 

In [274]: np.allclose(B*A, np.multiply(np.diag(B)[:,None], A)) 
Out[274]: True 

In [275]: %timeit B*A 
10 loops, best of 3: 32.1 ms per loop 

In [276]: %timeit np.multiply(np.diag(B)[:,None], A) 
10000 loops, best of 3: 33 µs per loop 

In [282]: np.allclose(A.T*B, np.multiply(A.T,np.diag(B))) 
Out[282]: True 

In [283]: %timeit A.T*B 
10 loops, best of 3: 24.1 ms per loop 

In [284]: %timeit np.multiply(A.T,np.diag(B)) 
10000 loops, best of 3: 36.2 µs per loop 
0

顯示我的(A的初始權利要求* BT).T變慢是不正確的。

from timeit import default_timer as timer 
import numpy as np 

##### Case 1 
a = np.zeros((4000,4000)) 
np.fill_diagonal(a, 10) 
b = np.ones((4000,5)) 

dot_list = [] 

def time_dot(a,b): 
    start = timer() 
    c = np.dot(a,b) 
    end = timer() 
    return end - start 

for i in range(100): 
    dot_list.append(time_dot(a,b)) 

print np.mean(np.asarray(dot_list)) 

##### Case 2 
a = np.ones((4000,)) 
a = a * 10 
b = np.ones((4000,5)) 

shortcut_list = [] 

def time_quicker(a,b): 
    start = timer() 
    c = (a*b.T).T 
    end = timer() 
    return end - start 

for i in range(100): 
    shortcut_list.append(time_quicker(a,b)) 

print np.mean(np.asarray(shortcut_list)) 


##### Case 3 
a = np.zeros((4000,4000)) #diagonal matrix 
np.fill_diagonal(a, 10) 
b = np.ones((4000,5)) 

case3_list = [] 

def function(a,b): 
    start = timer() 
    np.multiply(b.T,np.diag(a)) 
    end = timer() 
    return end - start 

for i in range(100): 
    case3_list.append(function(a,b)) 

print np.mean(np.asarray(case3_list)) 

結果:

0.119120892431

0.00010633951868

0.00214490709662

所以第二種方法是最快

+0

這個問題陳述了'A(4000,5),B(4000,4000)',這裏你正在使用它們翻轉。認爲你需要解決這個問題或這篇文章。 – Divakar

+0

@Divakar,你是對的,修正了這個問題 – LMB

+0

同樣,在你說的評論中你使用的是NumPy矩陣,並且你在這裏使用的是NumPy數組。 – Divakar