2016-10-17 71 views
2

我在做MNIST數據集並試圖獲得我的兩個向量w_i(ith class)a_k(kth sample)的外積。python中的多維外部產品

w_i,對於i = 0...9,具有784個座標。

a_k,爲k = 1...n,也有784座標。

我創建了兩個數組w_ija_ij,它們包含全部十個類和k個樣本。的形狀是(10,784),並且a_ij是(n,784)。

我試圖得到的結果是這樣的:

[[w_0 dot a_1, w_0 dot a_2, ... , w_0 dot a_n], # (first row) 
[w_1 dot a_1, w_1 dot a_2, ..., w_1 dot a_n], # (second row) 
..., 
[w_9 dot a_1, ..., w_9 dot a_n]] # (nth row) 

因此陣列的形狀應該像(10, n)。我試圖使用scipy.outer(w_ij, a_k)scipy.multiply.outer(w_ij, a_k)。但是,它導致我的結果是形狀爲(7840, 784*n)。有人能指引我走向正確的道路嗎?

+0

注意:如果我可以得到沒有使用任何循環的結果,我被挑戰。所以我試圖避免使用它。 – whh1294

回答

3

它看起來像你想以下幾點:

res = np.einsum('pi,qi->pq', w, a) 

這是簡寫形式,在索引符號如下:

res[p,q] = w[p,i]*a[q,i] 

這種表示法,慣例是總結在這做的所有指標不出現在輸出中


但是,請注意ij,jk->ik是理由t是標準矩陣乘積,而ij->ji只是矩陣轉置。所以我們可以簡化如下:

np.einsum('pi,qi->pq', w, a) # as before 
np.einsum('pi,iq->pq', w, a.T) # transpose and swapping indices cancel out 
np.einsum('ij,jk->ik', w, a.T) # index names don't matter 
w @ a.T      # wait a sec, this is just matrix multiplication (python 3.5+) 
+1

或'np.dot(w,a.T)'? – hpaulj

+0

感謝您的提示! – whh1294