2013-09-28 187 views

Answers

64
import numpy as np 
import matplotlib.pyplot as plt 

# sample data 
x = np.arange(10) 
y = 5*x + 10 

# fit with np.polyfit 
m, b = np.polyfit(x, y, 1) 

plt.plot(x, y, '.') 
plt.plot(x, m*x + b, '-') 


19

I prefer scikits.statsmodels. Here is an example:

import statsmodels.api as sm 
import numpy as np 
import matplotlib.pyplot as plt 

X = np.random.rand(100) 
Y = X + np.random.rand(100)*0.1 

results = sm.OLS(Y,sm.add_constant(X)).fit() 

print(results.summary())

plt.scatter(X,Y) 

X_plot = np.linspace(0,1,100) 
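# add_constant prepends the constant column by default in current statsmodels,
# so params[0] is the intercept and params[1] is the slope
plt.plot(X_plot, X_plot*results.params[1] + results.params[0])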

plt.show() 

The only tricky part is sm.add_constant(X), which adds a column of ones to X so that the model gets an intercept term.

     Summary of Regression Results
=======================================
| Dependent Variable:            ['y']|
| Model:                           OLS|
| Method:                Least Squares|
| Date:               Sat, 28 Sep 2013|
| Time:                       09:22:59|
| # obs:                         100.0|
| Df residuals:                   98.0|
| Df model:                        1.0|
==============================================================================
|                   coefficient     std. error    t-statistic          prob. |
------------------------------------------------------------------------------
| x1                      1.007       0.008466       118.9032         0.0000 |
| const                 0.05165       0.005138        10.0515         0.0000 |
==============================================================================
|                         Models stats                        Residual stats |
------------------------------------------------------------------------------
| R-squared:                    0.9931  Durbin-Watson:                 1.484 |
| Adjusted R-squared:           0.9930  Omnibus:                       12.16 |
| F-statistic:               1.414e+04  Prob(Omnibus):              0.002294 |
| Prob (F-statistic):       9.137e-108  JB:                           0.6818 |
| Log likelihood:                223.8  Prob(JB):                     0.7111 |
| AIC criterion:                -443.7  Skew:                        -0.2064 |
| BIC criterion:                -438.5  Kurtosis:                      2.048 |
------------------------------------------------------------------------------
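
To see why the column of ones matters, here is a minimal sketch in plain NumPy. It is not how statsmodels works internally: np.linalg.lstsq stands in for sm.OLS as an assumption for illustration, and the data reuses the noise-free line y = 5*x + 10 from the first answer. Without the constant column, the fit is forced through the origin and the slope comes out wrong.

```python
import numpy as np

# Noise-free sample data on the line y = 5*x + 10
x = np.arange(10, dtype=float)
y = 5 * x + 10

# Design matrix without a constant column: the fit is forced through the origin
A_no_const = x[:, np.newaxis]
slope_only, *_ = np.linalg.lstsq(A_no_const, y, rcond=None)

# Design matrix with a leading column of ones: the fit can learn an intercept
A_const = np.column_stack([np.ones_like(x), x])
coef, *_ = np.linalg.lstsq(A_const, y, rcond=None)
intercept, slope = coef

print("without constant: slope =", slope_only[0])
print("with constant: intercept =", intercept, "slope =", slope)
```

With the ones column in place, the least-squares solution recovers both the intercept and the slope of the underlying line.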


+2

My figure looks different; the line is in the wrong place, it sits too high. – David

+2

@David: the params array is the wrong way round. Try: plt.plot(X_plot, X_plot*results.params[1] + results.params[0]). Or, even better: plt.plot(X, results.fittedvalues), since the first formula assumes y is linear in x, which is not always the case. – Ian
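
The index mix-up in these comments comes down to where the constant column sits in the design matrix. The sketch below illustrates this with plain NumPy. np.linalg.lstsq is used as a stand-in for sm.OLS, and the slope 2.0 and intercept 0.5 are made-up values, so treat it as an illustration rather than a description of any particular statsmodels version.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(100)
y = 2.0 * x + 0.5  # noise-free line: slope 2.0, intercept 0.5 (illustrative values)

ones = np.ones_like(x)

# Constant column first: coefficients come back as [intercept, slope]
params_prepend, *_ = np.linalg.lstsq(np.column_stack([ones, x]), y, rcond=None)

# Constant column last: coefficients come back as [slope, intercept]
params_append, *_ = np.linalg.lstsq(np.column_stack([x, ones]), y, rcond=None)

# Multiplying the design matrix by its coefficients avoids the ordering question
fitted = np.column_stack([ones, x]) @ params_prepend

print("constant first:", params_prepend)
print("constant last: ", params_append)
```

Plotting fitted values, as suggested with results.fittedvalues, follows the same idea: the values come from the design matrix and the coefficients together, so the order of the coefficients no longer matters.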

8

Another way to do it, using axes.get_xlim():

import matplotlib.pyplot as plt 
import numpy as np 

def scatter_plot_with_correlation_line(x, y, graph_filepath): 
    ''' 
    http://stackoverflow.com/a/34571821/395857 
    x does not have to be ordered. 
    ''' 
    # Scatter plot 
    plt.scatter(x, y) 

    # Add correlation line 
    axes = plt.gca() 
    m, b = np.polyfit(x, y, 1) 
    X_plot = np.linspace(axes.get_xlim()[0],axes.get_xlim()[1],100) 
    plt.plot(X_plot, m*X_plot + b, '-') 

    # Save figure 
    plt.savefig(graph_filepath, dpi=300, format='png', bbox_inches='tight') 

def main(): 
    # Data 
    x = np.random.rand(100) 
    y = x + np.random.rand(100)*0.1 

    # Plot 
    scatter_plot_with_correlation_line(x, y, 'scatter_plot.png') 

if __name__ == "__main__": 
    main() 
    #cProfile.run('main()') # if you want to do some profiling 


9

A one-line version of this excellent answer to plot the line of best fit is:

plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x))) 

Using np.unique(x) instead of x handles the case where x is not sorted or has duplicate values.

The call to poly1d is an alternative to writing out m*x + b, as in this other excellent answer.
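
As a small check of the two points above, the following sketch uses made-up data on the line y = 3*x + 1, with x deliberately unsorted and containing a repeated value. It shows that np.unique returns the x values sorted with duplicates removed, and that the poly1d object evaluates to the same values as m*x + b.

```python
import numpy as np

# Unsorted x with a repeated value (made-up data on the line y = 3*x + 1)
x = np.array([3.0, 1.0, 2.0, 1.0, 4.0])
y = 3 * x + 1

xs = np.unique(x)            # sorted, duplicates removed
m, b = np.polyfit(x, y, 1)   # slope and intercept of the degree-1 fit
line = np.poly1d([m, b])     # callable polynomial equivalent to m*x + b

print(xs)
print(np.allclose(line(xs), m * xs + b))
```

Because xs is sorted, plt.plot(xs, line(xs)) draws a single left-to-right segment instead of zigzagging between points in their original order.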

+1

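Hi, my x and y values are arrays converted from lists using 'numpy.asarray'. When I add this line of code, I get several lines on my scatter plot instead of one. What could be the reason? – artre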

+1

@artre Thanks for bringing this up. That can happen if 'x' is not sorted or has duplicate values. I have edited the answer. –

2
plt.plot(X_plot, X_plot*results.params[0] + results.params[1]) 

plt.plot(X_plot, X_plot*results.params[1] + results.params[0])