2017-08-31 229 views
1

給出一個三維陣列和二維陣列的最終軸,相乘的elementwise超過兩個陣列

a = np.arange(10*4*3).reshape((10,4,3)) 
b = np.arange(30).reshape((10,3)) 

如何運行跨過每個的最終軸的elementwise乘法,導致c其中c具有形狀.shape作爲a?即

c[0] = a[0] * b[0] 
c[1] = a[1] * b[1] 
# ... 
c[i] = a[i] * b[i] 

回答

2

沒有任何款項,減少參與,一個簡單的broadcasting將與np.newaxis/None延伸b3D後真正有效 -

a*b[:,None,:] # or simply a*b[:,None] 

運行測試 -

In [531]: a = np.arange(10*4*3).reshape((10,4,3)) 
    ...: b = np.arange(30).reshape((10,3)) 
    ...: 

In [532]: %timeit np.einsum('ijk,ik->ijk', a, b) #@Brad Solomon's soln 
    ...: %timeit a*b[:,None] 
    ...: 
100000 loops, best of 3: 1.79 µs per loop 
1000000 loops, best of 3: 1.66 µs per loop 

In [525]: a = np.random.rand(100,100,100) 

In [526]: b = np.random.rand(100,100) 

In [527]: %timeit np.einsum('ijk,ik->ijk', a, b) 
    ...: %timeit a*b[:,None] 
    ...: 
1000 loops, best of 3: 1.53 ms per loop 
1000 loops, best of 3: 1.08 ms per loop 

In [528]: a = np.random.rand(400,400,400) 

In [529]: b = np.random.rand(400,400) 

In [530]: %timeit np.einsum('ijk,ik->ijk', a, b) 
    ...: %timeit a*b[:,None] 
    ...: 
10 loops, best of 3: 128 ms per loop 
10 loops, best of 3: 94.8 ms per loop 
+0

@BradSolomon這是正確的。增加了時間來確認這些。 – Divakar

1

使用np.einsum

c = np.einsum('ijk,ik->ijk', a, b) 

快速檢查:

print(np.allclose(c[0], a[0] * b[0])) 
print(np.allclose(c[1], a[1] * b[1])) 
print(np.allclose(c[-1], a[-1] * b[-1])) 
# True 
# True 
# True