2013-03-08 103 views
7

我對numpy的廣播規則有些困惑。假設你想執行一個更高維數組的軸明智的數量積可以降低1陣列空間維度(基本上執行沿一個軸的加權和):Numpy Array不同維度的廣播

from numpy import * 

A = ones((3,3,2)) 
v = array([1,2]) 

B = zeros((3,3)) 

# V01: this works 
B[0,0] = v.dot(A[0,0]) 

# V02: this works 
B[:,:] = v[0]*A[:,:,0] + v[1]*A[:,:,1] 

# V03: this doesn't 
B[:,:] = v.dot(A[:,:]) 

爲什麼V03不行?

乾杯

回答

4

np.dot(a, b)over the last axis of a and the second-to-last of b操作。所以對於你的問題你的具體情況,你總是可以去:

>>> a.dot(v) 
array([[ 3., 3., 3.], 
     [ 3., 3., 3.], 
     [ 3., 3., 3.]]) 

如果你想保持v.dot(a)順序,你需要得到軸到位,它可以很容易地與np.rollaxis實現:

>>> v.dot(np.rollaxis(a, 2, 1)) 
array([[ 3., 3., 3.], 
     [ 3., 3., 3.], 
     [ 3., 3., 3.]]) 

我不喜歡np.dot太多,除非它是明顯的矩陣或向量乘法,因爲它使用可選out參數時是非常嚴格的輸出D型。 Joe Kington已經提到過它,但是如果你打算做這種事情,那就習慣np.einsum:一旦你掌握了這個語法,它就會減少你花費在重塑事物上的時間最低:

>>> a = np.ones((3, 3, 2)) 
>>> np.einsum('i, jki', v, a) 
array([[ 3., 3., 3.], 
     [ 3., 3., 3.], 
     [ 3., 3., 3.]]) 

不在於它是在這種情況下也有關,但它也快得離譜:

In [4]: %timeit a.dot(v) 
100000 loops, best of 3: 2.43 us per loop 

In [5]: %timeit v.dot(np.rollaxis(a, 2, 1)) 
100000 loops, best of 3: 4.49 us per loop 

In [7]: %timeit np.tensordot(v, a, axes=(0, 2)) 
100000 loops, best of 3: 14.9 us per loop 

In [8]: %timeit np.einsum('i, jki', v, a) 
100000 loops, best of 3: 2.91 us per loop 
2

您可以使用numpy.apply_along_axis()此:

In [35]: np.apply_along_axis(v.dot, 2, A) 
Out[35]: 
array([[ 3., 3., 3.], 
     [ 3., 3., 3.], 
     [ 3., 3., 3.]]) 

究其原因,我認爲V03不起作用的是,它沒有什麼不同來:

B[:,:] = v.dot(A) 

即它試圖計算沿着A的最外軸的點積。

3

在這種特殊情況下,您也可以使用tensordot

import numpy as np 

A = np.ones((3,3,2)) 
v = np.array([1,2]) 

print np.tensordot(v, A, axes=(0, 2)) 

這產生了:

array([[ 3., 3., 3.], 
     [ 3., 3., 3.], 
     [ 3., 3., 3.]]) 

axes=(0,2)表明tensordot應該總結以上在v第一軸線和在A第三軸線。 (也有看einsum,這是更靈活,但很難理解,如果你不習慣的符號。)

如果速度是一個考慮因素,tensordot比使用apply_along_axes爲小數組相當快。

In [14]: A = np.ones((3,3,2)) 

In [15]: v = np.array([1,2]) 

In [16]: %timeit np.tensordot(v, A, axes=(0, 2)) 
10000 loops, best of 3: 21.6 us per loop 

In [17]: %timeit np.apply_along_axis(v.dot, 2, A) 
1000 loops, best of 3: 258 us per loop 

(所不同的是對於大型陣列不太明顯由於恆定的開銷,雖然tensordot是一致更快。)

+0

我發現numpy的tensordot軸完全混亂。你能詳細說明嗎?我想通過一個(3,5)張量來多重(10,3,2)張量來得到一個(2)張量。 – Bob 2015-12-22 22:31:15