2016-03-30 56 views
3

假設我有一個N維numpy數組x和一個(N-1)維索引數組m(例如,m = x.argmax(axis=-1))。我想構造(N-1)維數組y,使得y[i_1, ..., i_N-1] = x[i_1, ..., i_N-1, m[i_1, ..., i_N-1]](對於上面的argmax示例,它將等於y = x.max(axis=-1))。 n = 3時,我可以實現我想要numpy索引(與最大/ argmax有關)

y = x[np.arange(x.shape[0])[:, np.newaxis], np.arange(x.shape[1]), m] 

的問題是,我怎麼一個任意N做到這一點?

回答

1

下面是使用reshapinglinear indexing處理任意尺寸的多維數組的一個方法 -

shp = x.shape[:-1] 
n_ele = np.prod(shp) 
y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp) 

讓我們用6 dimensions一個ndarray樣本的情況下,讓我們說,我們正在使用m = x.argmax(axis=-1)索引到最後尺寸。所以,輸出將是x.max(-1)。讓我們驗證這一點對於提出的解決方案 -

In [121]: x = np.random.randint(0,9,(4,5,3,3,2,4)) 

In [122]: m = x.argmax(axis=-1) 

In [123]: shp = x.shape[:-1] 
    ...: n_ele = np.prod(shp) 
    ...: y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp) 
    ...: 

In [124]: np.allclose(x.max(-1),y_out) 
Out[124]: True 

我喜歡@B. M.'s solution其優雅。所以,這裏的運行時測試,以衡量這兩個 -

def reshape_based(x,m): 
    shp = x.shape[:-1] 
    n_ele = np.prod(shp) 
    return x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp) 

def indices_based(x,m): ## @B. M.'s solution 
    firstdims=np.indices(x.shape[:-1]) 
    ind=tuple(firstdims)+(m,) 
    return x[ind] 

計時 -

In [152]: x = np.random.randint(0,9,(4,5,3,3,4,3,6,2,4,2,5)) 
    ...: m = x.argmax(axis=-1) 
    ...: 

In [153]: %timeit indices_based(x,m) 
10 loops, best of 3: 30.2 ms per loop 

In [154]: %timeit reshape_based(x,m) 
100 loops, best of 3: 5.14 ms per loop 
+0

謝謝!這工作。我只會在最後一行加上'.reshape(shp)' – BlindDriver

+0

@BindDriver是的,這是一個錯字,編輯了那個錯誤。 – Divakar

2

您可以使用indices

firstdims=np.indices(x.shape[:-1]) 

並添加您:

ind=tuple(firstdims)+(m,) 

然後x[ind]是你想要的。

In [228]: allclose(x.max(-1),x[ind]) 
Out[228]: True 
+0

更加優雅!做那個'np.indices'? – Divakar

+0

不錯!不知道這個索引函數。 – BlindDriver