2017-02-26 42 views
2

給出兩條曲線的交點我有兩個數據集:(X,Y 1)和(x,y2)上。我想找到這兩條曲線相互交叉的位置。我們的目標是類似這樣的問題:Intersection of two graphs in Python, find the x value:查找(X,Y)數據以高精度在Python

然而,所描述的方法只發現交點到最接近的數據點。我想找到比原始數據間距具有更高精度的曲線的交點。一種選擇是簡單地重新插值到更精細的網格。這是有效的,但是然後精度由我選擇用於重新插值的點的數量決定,這是任意的,並且需要在精度和效率之間進行權衡。

可替換地,我可以使用scipy.optimize.fsolve查找數據集的所述兩個花鍵插補的確切交集。這很好,但它不容易找到多個交點,要求我爲交點提供合理的猜測,並且可能不能很好地縮放。 (最後,我想找到的幾千套(X,Y1,Y2)的交叉點,所以一個有效的算法將是很好的。)

這是我到目前爲止所。任何改進想法?

import numpy as np 
import matplotlib.pyplot as plt 
import scipy.interpolate, scipy.optimize 

x = np.linspace(1, 4, 20) 
y1 = np.sin(x) 
y2 = 0.05*x 

plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1') 
plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2') 

idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0) 

plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method') 

interp1 = scipy.interpolate.InterpolatedUnivariateSpline(x, y1) 
interp2 = scipy.interpolate.InterpolatedUnivariateSpline(x, y2) 

new_x = np.linspace(x.min(), x.max(), 100) 
new_y1 = interp1(new_x) 
new_y2 = interp2(new_x) 
idx = np.argwhere(np.diff(np.sign(new_y1 - new_y2)) != 0) 
plt.plot(new_x[idx], new_y1[idx], 'ro', ms=7, label='Nearest data-point method, with re-interpolated data') 

def difference(x): 
    return np.abs(interp1(x) - interp2(x)) 

x_at_crossing = scipy.optimize.fsolve(difference, x0=3.0) 
plt.plot(x_at_crossing, interp1(x_at_crossing), 'cd', ms=7, label='fsolve method') 

plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left') 

plt.savefig('curve crossing.png', dpi=200) 
plt.show() 

enter image description here

+0

是不是總有精度和效率之間的權衡?您可以繼續插入更細的網格,直到您的答案收斂到可容忍的數量範圍內。 – Crispin

+1

是不是來自網格交叉點的近似信息正是您設置樣條交叉問題所需的信息?我能看到的唯一問題就是如果在單個網格單元中有多個交點。我將運行網格交叉點,然後使用其中的答案來解決樣條交集,使用樣條線限制到發現的網格交點附近的幾個單元格。 – mcdowella

+0

waterboy5281,我想你是對的,在給定相同算法的情況下,通常會在效率和精度之間進行權衡。但是,更好的算法通常既快速又精確。 @mcdowella,我喜歡通過「最近的數據點」法求交點的近似位置,然後使用該信息以使其更容易找到精確交叉點的想法。我會盡力實現這一點。 – DanHickstein

回答

1

最好的(也是最高效的)答案很可能取決於數據集以及它們是如何採樣。但是,對於許多數據集來說,一個很好的近似值是它們在數據點之間幾乎是線性的。因此,我們可以通過原始文章中顯示的「最近的數據點」方法找到交集的大概位置。然後,我們可以使用線性插值來細化最近兩個數據點之間的交點位置。

這種方法是非常快,並與2D numpy的數組,你想同時計算多條曲線的交叉工作,萬一(我想在我的應用程序執行)。

(我借用「How do I compute the intersection point of two lines in Python?」代碼的線性插值。)

from __future__ import division 
import numpy as np 
import matplotlib.pyplot as plt 

def interpolated_intercept(x, y1, y2): 
    """Find the intercept of two curves, given by the same x data""" 

    def intercept(point1, point2, point3, point4): 
     """find the intersection between two lines 
     the first line is defined by the line between point1 and point2 
     the first line is defined by the line between point3 and point4 
     each point is an (x,y) tuple. 

     So, for example, you can find the intersection between 
     intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5) 

     Returns: the intercept, in (x,y) format 
     """  

     def line(p1, p2): 
      A = (p1[1] - p2[1]) 
      B = (p2[0] - p1[0]) 
      C = (p1[0]*p2[1] - p2[0]*p1[1]) 
      return A, B, -C 

     def intersection(L1, L2): 
      D = L1[0] * L2[1] - L1[1] * L2[0] 
      Dx = L1[2] * L2[1] - L1[1] * L2[2] 
      Dy = L1[0] * L2[2] - L1[2] * L2[0] 

      x = Dx/D 
      y = Dy/D 
      return x,y 

     L1 = line([point1[0],point1[1]], [point2[0],point2[1]]) 
     L2 = line([point3[0],point3[1]], [point4[0],point4[1]]) 

     R = intersection(L1, L2) 

     return R 

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0) 
    xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1]))) 
    return xc,yc 

def main(): 
    x = np.linspace(1, 4, 20) 
    y1 = np.sin(x) 
    y2 = 0.05*x 

    plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1') 
    plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2') 

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0) 

    plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method') 

    # new method! 
    xc, yc = interpolated_intercept(x,y1,y2) 
    plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation') 


    plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left') 

    plt.savefig('curve crossing.png', dpi=200) 
    plt.show() 

if __name__ == '__main__': 
    main() 

Curve crossing

相關問題