2013-10-09 128 views
5

默認矩陣乘法計算定製積numpy的矩陣乘法

c[i,j] = sum(a[i,k] * b[k,j]) 

我想使用的,而不是點產品的定製式求得

c[i,j] = sum(a[i,k] == b[k,j]) 

是否有一個有效的如何在numpy中做到這一點?

回答

4

您可以使用廣播:

c = sum(a[...,np.newaxis]*b[np.newaxis,...],axis=1) # == np.dot(a,b) 

c = sum(a[...,np.newaxis]==b[np.newaxis,...],axis=1) 

我列入bnewaxis只是明確表示數組如何擴大。還有其他一些向數組添加維度的方法(重塑,重複等),但效果是一樣的。將ab展開成相同的形狀以逐元素相乘(或==),然後求和正確的軸。

+0

+1不錯,謝謝 –

+0

謝謝。這非常整齊。 –