下面是使用reshaping
和linear indexing
處理任意尺寸的多維數組的一個方法 -
shp = x.shape[:-1]
n_ele = np.prod(shp)
y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
讓我們用6 dimensions
一個ndarray
樣本的情況下,讓我們說,我們正在使用m = x.argmax(axis=-1)
索引到最後尺寸。所以,輸出將是x.max(-1)
。讓我們驗證這一點對於提出的解決方案 -
In [121]: x = np.random.randint(0,9,(4,5,3,3,2,4))
In [122]: m = x.argmax(axis=-1)
In [123]: shp = x.shape[:-1]
...: n_ele = np.prod(shp)
...: y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
...:
In [124]: np.allclose(x.max(-1),y_out)
Out[124]: True
我喜歡@B. M.'s
solution其優雅。所以,這裏的運行時測試,以衡量這兩個 -
def reshape_based(x,m):
shp = x.shape[:-1]
n_ele = np.prod(shp)
return x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
def indices_based(x,m): ## @B. M.'s solution
firstdims=np.indices(x.shape[:-1])
ind=tuple(firstdims)+(m,)
return x[ind]
計時 -
In [152]: x = np.random.randint(0,9,(4,5,3,3,4,3,6,2,4,2,5))
...: m = x.argmax(axis=-1)
...:
In [153]: %timeit indices_based(x,m)
10 loops, best of 3: 30.2 ms per loop
In [154]: %timeit reshape_based(x,m)
100 loops, best of 3: 5.14 ms per loop
謝謝!這工作。我只會在最後一行加上'.reshape(shp)' – BlindDriver
@BindDriver是的,這是一個錯字,編輯了那個錯誤。 – Divakar