2016-06-21 29 views
3

通常,當我們知道應該在哪裏插入新軸時,我們可以做a[:, np.newaxis,...]。有沒有什麼好的方法可以將軸插入某個軸?在NumPy數組中插入可變位置的新軸

這是我現在怎麼做。我想一定是有比這更好的方法:

def addNewAxisAt(x, axis): 
    _s = list(x.shape) 
    _s.insert(axis, 1) 
    return x.reshape(tuple(_s)) 

def addNewAxisAt2(x, axis): 
    ind = [slice(None)]*x.ndim 
    ind.insert(axis, np.newaxis) 
    return x[ind] 

回答

1

np.insert確實

slobj = [slice(None)]*ndim 
... 
slobj[axis] = slice(None, index) 
... 
new[slobj] = arr[slobj2] 

就像你它構造片的名單,並修改一個或多個元素。

apply_along_axis構建一個陣列,並將其轉換爲索引元組

outarr[tuple(i.tolist())] = res 

其他numpy的功能,這種工作方式爲好。

我的建議是使初始列表大到足以容納None。然後,我並不需要使用insert

In [1076]: x=np.ones((3,2,4),int) 

In [1077]: ind=[slice(None)]*(x.ndim+1) 

In [1078]: ind[2]=None 

In [1080]: x[ind].shape 
Out[1080]: (3, 2, 1, 4) 

In [1081]: x[tuple(ind)].shape # sometimes converting a list to tuple is wise 
Out[1081]: (3, 2, 1, 4) 

原來有一個np.expand_dims

In [1090]: np.expand_dims(x,2).shape 
Out[1090]: (3, 2, 1, 4) 

它使用reshape像你一樣,但創建了元組拼接新的形狀。

def expand_dims(a, axis): 
    a = asarray(a) 
    shape = a.shape 
    if axis < 0: 
     axis = axis + len(shape) + 1 
    return a.reshape(shape[:axis] + (1,) + shape[axis:]) 

時間表並沒有告訴我有關哪個更好。它們是2μs範圍,只需將函數中的代碼包裝起來就可以改變。

3

這個單維(dim length = 1)可以添加作爲形狀標準,原來的陣列形狀np.insert,從而直接改變其形狀,像這樣 -

x.shape = np.insert(x.shape,axis,1) 

那麼,我們不妨延長這一邀請一個以上的新的座標軸帶着幾分np.diffnp.cumsum伎倆,像這樣 -

insert_idx = (np.diff(np.append(0,axis))-1).cumsum()+1 
x.shape = np.insert(x.shape,insert_idx,1) 

樣品試驗 -

In [151]: def addNewAxisAt(x, axis): 
    ...:  insert_idx = (np.diff(np.append(0,axis))-1).cumsum()+1 
    ...:  x.shape = np.insert(x.shape,insert_idx,1) 
    ...:  

In [152]: A = np.random.rand(4,5) 

In [153]: addNewAxisAt(A, axis=1) 

In [154]: A.shape 
Out[154]: (4, 1, 5) 

In [155]: A = np.random.rand(5,6,8,9,4,2) 

In [156]: addNewAxisAt(A, axis=5) 

In [157]: A.shape 
Out[157]: (5, 6, 8, 9, 4, 1, 2) 

In [158]: A = np.random.rand(5,6,8,9,4,2,6,7) 

In [159]: addNewAxisAt(A, axis=(1,3,4,6)) 

In [160]: A.shape 
Out[160]: (5, 1, 6, 1, 1, 8, 1, 9, 4, 2, 6, 7) 
+0

哦,當然!很容易忘記,ndarray的形狀是一個可設置的屬性 – wim

+0

還有'numpy.reshape',它可以以非變化的方式做到這一點。 – user2357112

+0

@ user2357112是!通過使用由代碼中'np.insert'輸出的新形狀信息。 – Divakar

0

至於我可以告訴你可以只使用元組索引和省略號對象

例如,擺在第4位,這將是這樣的:

>>> a = np.ones((3,3,3,3,3,3,3)) 
>>> a.shape 
(3, 3, 3, 3, 3, 3, 3) 
>>> t = tuple([slice(None)]*4 + [np.newaxis, Ellipsis]) 
>>> a[t].shape 
(3, 3, 3, 3, 1, 3, 3, 3) 

編輯:我的解決方法是跛腳,更喜歡answer from Divakar