2017-12-18 86 views
1

我有一個函數來計算1d np.array的有限差分,我想外推到一個n-d數組。Python中的「動態」N維有限差異

功能是這樣的:

def fpp_fourth_order_term(U): 
    """Returns the second derivative of fourth order term without the interval multiplier.""" 
    # U-slices 
    fm2 = values[ :-4] 
    fm1 = values[1:-3] 
    fc0 = values[2:-2] 
    fp1 = values[3:-1] 
    fp2 = values[4: ] 

    return -fm2 + 16*(fm1+fp1) - 30*fc0 - fp2 

它缺少四階乘數(1/(12*h**2)),但這不要緊,因爲分組的條款時,我就會大量繁殖。

我很想把它作爲一個N維擴展。對此,我會做以下修改:

def fpp_fourth_order_term(U, axis=0): 
    """Returns the second derivative of fourth order term along an axis without the interval multiplier.""" 
    # U-slices 

但這裏是問題

fm2 = values[ :-4] 
    fm1 = values[1:-3] 
    fc0 = values[2:-2] 
    fp1 = values[3:-1] 
    fp2 = values[4: ] 

這工作在1D罰款,如果是2D 沿着第一軸例如我將不得不改變是這樣的:

fm2 = values[:-4,:] 
    fm1 = values[1:-3,:] 
    fc0 = values[2:-2,:] 
    fp1 = values[3:-1,:] 
    fp2 = values[4:,:] 

但沿着第二軸將是:

fm2 = values[:,:-4] 
    fm1 = values[:,1:-3] 
    fc0 = values[:,2:-2] 
    fp1 = values[:,3:-1] 
    fp2 = values[:,4:] 

這同樣適用於3d,但有3種可能性,並繼續。如果鄰居設置正確,返回總是有效的。

return -fm2 + 16*(fm1+fp1) - 30*fc0 - fp2 

當然axis不能超過len(U.shape)-1大(我把這個尺寸,有沒有什麼辦法,而不是提取該段?

我如何做這個編碼問題一個優雅和Python的方法呢?

有沒有更好的方式來做到這一點

PS:關於np.diffnp.gradient,那些沒有工作,因爲第一個是一階,第二個我我做了第四階近似。事實上,我很快就完成了這個問題,我也將推廣這個命令。但是,是的,我希望能夠在任何軸上做np.gradient

+0

典雅可能會問太多,但(這些)(https://stackoverflow.com/q/42817508/7207392)[鏈接](https://stackoverflow.com/q/24398708/7207392)可以幫助你開始。 –

+0

'切片'模式不支持**花梢**索引作爲'U [:,1:-3]'另一種方式('np.take')依稀支持,我需要構建範圍(自'range(1,-1)'是沒有意義的),也是'np.take'複製數據(我想避免的,但是謝謝,我正在進步 – Lin

+0

'''fc0 = values.take(indices =範圍(2,values.shape [axis] -2),axis = axis)',可以工作,但是複製陣列5次會讓我記憶猶新。這些陣列會變大。 : -/ – Lin

回答

2

一個簡單而有效的解決方案是,在開始和結束你的程序的使用swapaxes

import numpy as np 

def f(values, axis=-1): 
    values = values.swapaxes(0, axis) 

    fm2 = values[ :-4] 
    fm1 = values[1:-3] 
    fc0 = values[2:-2] 
    fp1 = values[3:-1] 
    fp2 = values[4: ] 

    return (-fm2 + 16*(fm1+fp1) - 30*fc0 - fp2).swapaxes(0, axis) 

a = (np.arange(4*7*8)**3).reshape(4,7,8) 
res = f(a, axis=1) 
print(res) 
print(res.flags) 

輸出:

# [[[ 73728 78336 82944 87552 92160 96768 101376 105984] 
# [110592 115200 119808 124416 129024 133632 138240 142848] 
# [147456 152064 156672 161280 165888 170496 175104 179712]] 

# [[331776 336384 340992 345600 350208 354816 359424 364032] 
# [368640 373248 377856 382464 387072 391680 396288 400896] 
# [405504 410112 414720 419328 423936 428544 433152 437760]] 

# [[589824 594432 599040 603648 608256 612864 617472 622080] 
# [626688 631296 635904 640512 645120 649728 654336 658944] 
# [663552 668160 672768 677376 681984 686592 691200 695808]] 

# [[847872 852480 857088 861696 866304 870912 875520 880128] 
# [884736 889344 893952 898560 903168 907776 912384 916992] 
# [921600 926208 930816 935424 940032 944640 949248 953856]]] 

其結果是,即使是連續的。

# C_CONTIGUOUS : True 
# F_CONTIGUOUS : False 
# OWNDATA : False 
# WRITEABLE : True 
# ALIGNED : True 
# UPDATEIFCOPY : False