似乎在numba中不受支持。在numba jitted函數中動態增長陣列
在nopython模式下使用numba.jit
動態增長數組的最佳方式是什麼?
到目前爲止,我能做的最好的事情是定義並調整jitted函數之外的數組的大小,是否有更好的(和更整齊的)選項?
似乎在numba中不受支持。在numba jitted函數中動態增長陣列
在nopython模式下使用numba.jit
動態增長數組的最佳方式是什麼?
到目前爲止,我能做的最好的事情是定義並調整jitted函數之外的數組的大小,是否有更好的(和更整齊的)選項?
numpy.resize
是pure python function:
import numpy as np
def resize(a, new_shape):
"""I did some minor changes so it all works with just `import numpy as np`."""
if isinstance(new_shape, (int, np.core.numerictypes.integer)):
new_shape = (new_shape,)
a = np.ravel(a)
Na = len(a)
if not Na:
return np.zeros(new_shape, a.dtype)
total_size = np.multiply.reduce(new_shape)
n_copies = int(total_size/Na)
extra = total_size % Na
if total_size == 0:
return a[:0]
if extra != 0:
n_copies = n_copies+1
extra = Na-extra
a = np.concatenate((a,)*n_copies)
if extra > 0:
a = a[:-extra]
return np.reshape(a, new_shape)
對於一維數組,這將是直着自己實現。不幸的是,ND陣列要複雜得多,因爲一些操作在nopython numba函數中不受支持:isinstance
,reshape
和元組乘法。這裏是1D相當於:
import numpy as np
import numba as nb
@nb.njit
def resize(a, new_size):
new = np.zeros(new_size, a.dtype)
idx = 0
while True:
newidx = idx + a.size
if newidx > new_size:
new[idx:] = a[:new_size-newidx]
break
new[idx:newidx] = a
idx = newidx
return new
,當你不希望這樣「重複輸入」行爲,只會用它來增加大小那就更簡單了:
@nb.njit
def resize(a, new_size):
new = np.zeros(new_size, a.dtype)
new[:a.size] = a
return new
這些功能都裝飾與numba.njit
,因此可以在nopython模式中的任何numba函數中調用。
注意的一點是,雖然:一般來說,你不想來調整 - 或者如果你然後確保你選擇有amoritzed O(1)
cost (Wikipedia link)的方法。如果你可以估計最大長度,那麼最好立即預先分配一個正確大小(或略微超額分配)的數組。
通常,我使用的策略是隻分配足夠多的數組存儲以適應計算,然後跟蹤最終使用的索引/索引,然後在返回之前將數組切片降至實際大小。這假定您事先知道您可能將陣列增長到的最大大小。我的想法是,在我的大多數應用程序中,內存很便宜,但是調整大小並在python和jitted函數之間切換很貴。
謝謝,那就是我一直在尋找的。太糟糕了,ND陣列沒有簡單的解決方案。 – nivniv
您可以使用np.empty(new_size,a.dtype)代替非常小且有爭議的性能增益。 – tal