2014-01-18 141 views
3

快速插值這個問題類似於前問題回答Fast interpolation over 3D array,但解決不了我的問題。在三維陣列三維原點X

我有尺寸(時間,高度,緯度,經度),標記爲y.shape=(nt, nalt, nlat, nlon)一個四維陣列。 x是高度並隨(時間,緯度,經度)而變化,這意味着x.shape = (nt, nalt, nlat, nlon)。我想插入每個(nt,nlat,nlon)的高度。插值的x_new應該是1d,不會隨(時間,緯度,經度)而改變。

我用numpy.interp,一樣scipy.interpolate.interp1d,想想在原職的答案。這些答案我不能減少循環。

我只能這樣做:

# y is a 4D ndarray 
# x is a 4D ndarray 
# new_y is a 4D array 
for i in range(nlon): 
    for j in range(nlat): 
     for k in range(nt): 
      y_new[k,:,j,i] = np.interp(new_x, x[k,:,j,i], y[k,:,j,i]) 

這些循環使這個插值太慢了計算。會有人有好點子嗎?幫助將不勝感激。

+0

如果一個new_x是按升序排列,我想你可以寫一個用Cython功能來提高計算速度。 – HYRY

+0

@HYRY new_x按升序排列。但是原來的x可能不是從小到大的順序。這是否可用?我不熟悉Cython,也許需要一些時間來學習它,或者可以在前一篇文章中學習tiago的代碼。 – Hao

+0

但是,'numpy.interp'需要'x'按升序排列,所以需要先將x和y按x排序。 – HYRY

回答

1

這是通過使用numba我的解決方案,它是關於快三倍。

創建測試數據第一,x需要按升序排列:

import numpy as np 
rows = 200000 
cols = 66 
new_cols = 69 
x = np.random.rand(rows, cols) 
x.sort(axis=-1) 
y = np.random.rand(rows, cols) 
nx = np.random.rand(new_cols) 
nx.sort() 

做20萬次口譯在numpy的:

%%time 
ny = np.empty((x.shape[0], len(nx))) 
for i in range(len(x)): 
    ny[i] = np.interp(nx, x[i], y[i]) 

我用的合併方法,而不是二進制搜索方法,因爲nx是順序的,並且nx的長度與x大致相同。

  • interp()使用二進制搜索,時間複雜度爲O(len(nx)*log2(len(x))
  • 合併方法:時間複雜度爲O(len(nx) + len(x))

這裏是numba代碼:

import numba 

@numba.jit("f8[::1](f8[::1], f8[::1], f8[::1], f8[::1])") 
def interp2(x, xp, fp, f): 
    n = len(x) 
    n2 = len(xp) 
    j = 0 
    i = 0 
    while x[i] <= xp[0]: 
     f[i] = fp[0] 
     i += 1 

    slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])   
    while i < n: 
     if x[i] >= xp[j] and x[i] < xp[j+1]: 
      f[i] = slope*(x[i] - xp[j]) + fp[j] 
      i += 1 
      continue 
     j += 1 
     if j + 1 == n2: 
      break 
     slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j]) 

    while i < n: 
     f[i] = fp[n2-1] 
     i += 1 

@numba.jit("f8[:, ::1](f8[::1], f8[:, ::1], f8[:, ::1])") 
def multi_interp(x, xp, fp): 
    nrows = xp.shape[0] 
    f = np.empty((nrows, x.shape[0])) 
    for i in range(nrows): 
     interp2(x, xp[i, :], fp[i, :], f[i, :]) 
    return f 

然後調用numba功能:

%%time 
ny2 = multi_interp(nx, x, y) 

檢查結果:

np.allclose(ny, ny2) 

在我的電腦,時間爲:

python version: 3.41 s 
numba version: 1.04 s 

這種方法需要一個數組,其中最後一個軸是要interp()軸。

+0

非常感謝。那很棒!我明天會檢查它。 – Hao