np.broadcast_to
可'重塑'A
匹配B
;那麼你可以一起迭代兩者。它使用striding,所以沒有實際增加內存使用。
In [370]: def f(a,b):
...: assert(a.shape==(1,3))
...: assert(b.shape==(1,3))
...: return a+b
...:
In [371]: B=np.arange(12).reshape(4,3)
In [372]: A=np.arange(3).reshape(1,3)
In [373]: np.broadcast_to(A, B.shape) # (1,3) to (4,3)
Out[373]:
array([[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2]])
In [374]: np.broadcast_to(B, B.shape) # no change with (4,3)
Out[374]:
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
我通常使用列表解析而不是圖:
In [375]: [f(np.atleast_2d(a),np.atleast_2d(b)) for a,b in zip(np.broadcast_to(A,B.shape),B)]
Out[375]:
[array([[0, 2, 4]]),
array([[3, 5, 7]]),
array([[ 6, 8, 10]]),
array([[ 9, 11, 13]])]
In [376]: [f(np.atleast_2d(a),np.atleast_2d(b)) for a,b in zip(np.broadcast_to(B,B.shape),B)]
Out[376]:
[array([[0, 2, 4]]),
array([[ 6, 8, 10]]),
array([[12, 14, 16]]),
array([[18, 20, 22]])]
迭代在2D陣列產生一維數組的列表,因此需要np.atleast_2d
滿足我f
斷言。如果f
也接受(3,)輸入,我將不需要。
或者與map
:
In [377]: map(lambda a,b: f(np.atleast_2d(a),np.atleast_2d(b)), np.broadcast_to(B,B.shape),B)
Out[377]: <map at 0xb14f4c6c>
In [378]: list(_)
Out[378]:
[array([[0, 2, 4]]),
array([[ 6, 8, 10]]),
array([[12, 14, 16]]),
array([[18, 20, 22]])]
In [379]: map(lambda a,b: f(np.atleast_2d(a),np.atleast_2d(b)), np.broadcast_to(A,B.shape),B)
Out[379]: <map at 0xb0871a8c>
In [380]: list(_)
Out[380]:
[array([[0, 2, 4]]),
array([[3, 5, 7]]),
array([[ 6, 8, 10]]),
array([[ 9, 11, 13]])]
np.vectorize
和np.frompyfunc
處理這種廣播的爲好,但它們被設計爲帶標量,而不是一維數組功能。
隨着broadcast_arrays
我可以平等地對待兩個數組:
In [386]: map(lambda a,b: f(np.atleast_2d(a),np.atleast_2d(b)), *np.broadcast_arrays(B,A))
Out[386]: <map at 0xb69851ac>
In [387]: list(_)
Out[387]:
[array([[0, 2, 4]]),
array([[3, 5, 7]]),
array([[ 6, 8, 10]]),
array([[ 9, 11, 13]])]
更一般地,A
和B
可以是任何生產所需(N,3)
陣列。我可以通過生產(N,1,3)
陣列擺脫atleast_2d
的使用:
In [397]: map(f, *np.broadcast_arrays(np.arange(3)[None,None,:], np.arange(0,40,10)[:,None,None]))
Out[397]: <map at 0xb08b562c>
In [398]: list(_)
Out[398]:
[array([[0, 1, 2]]),
array([[10, 11, 12]]),
array([[20, 21, 22]]),
array([[30, 31, 32]])]