2016-04-29 49 views

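How do I plot a sigmoid probability curve in scikit-learn?

I am trying to recreate this image in Python, given 2 classes and their associated prediction probabilities from a classifier.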

I would like to see something like this: [image: sigmoid curve]

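It isn't working, though, because I get a mostly straight line. **Note: I know the data shown here is currently suspect and/or bad. I need to tune the inputs and the model, but I wanted to see the plot.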

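Basically, I thought I would "correct" the predict_proba() output so that every value is expressed relative to class "0". That is, if the model predicts class "1", the probability of class "0" is 1 - (class 1 probability), so a 95% prediction for class "1" becomes a 5% prediction for class "0". I would then plot by my "corrected" prediction values and end up with something sigmoid-ish.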
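For illustration, here is a minimal sketch of that "correction" using only the standard library. The `z_values` list is hypothetical and stands in for the model's linear predictor (intercept plus the weighted sum of the inputs). It shows that the class "0" probability is simply the complement of the class "1" probability, and that the S shape appears when probabilities are ordered by `z`, not by row position. For a binary classifier, `predict_proba()` already returns both columns, ordered as in `clf.classes_`.

```python
import math


def sigmoid(z):
    """Logistic function: maps a linear predictor z to P(class 1)."""
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical linear-predictor values, for illustration only.
z_values = [-4.0, -2.0, 0.0, 2.0, 4.0]

p_class1 = [sigmoid(z) for z in z_values]
# The "correction" described above: P(class 0) = 1 - P(class 1).
p_class0 = [1.0 - p for p in p_class1]

for z, p1, p0 in zip(z_values, p_class1, p_class0):
    print(f"z={z:+.1f}  P(class 1)={p1:.3f}  P(class 0)={p0:.3f}")
```

Plotting either column against the row index of unsorted data will generally not look like a sigmoid; the x-axis needs to be `z`.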

Unfortunately, I ended up with this: [image]

Here is a chunk of my Python where I tried (and failed) to plot the probability sigmoid:

########################### 
## I removed my original Python code because it was very, very wrong so as to avoid any confusion. 
########################### 

For reference, below is the Matlab plot that I want to replicate with my Python model.

%Build the model 
mdl = fitglm(X, Y, 'distr', 'binomial', 'link', 'logit') 
%Build the sigmoid model 
B = mdl.Coefficients{:, 1}; 
Z = mdl.Fitted.LinearPredictor 
yhat = glmval(B, X, 'logit'); 
figure, scatter(Z, yhat), hold on, 
gscatter(Z, zeros(length(X),1)-0.1, Y) % plot original classes 
hold off, xlabel('\bf Z'), grid on, ylim([-0.2 1.05]) 
title('\bf Predicted Probability of each record') 

Answer


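There may be a more Pythonic way to do this, but here is what I was able to come up with in the end:

(Keep in mind that the data in this case is not perfectly separable, so the curve doesn't have the traditional look of a sigmoid with the classes separated at the 0.50 point.)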

############################################################################# 
#### Draws a sigmoid probability plot from prediction results ############### 
############################################################################# 
import matplotlib.pyplot as plt 
import numpy as np 
print ('-'*40) 

# make the predictions (class) and also get the prediction probabilities 
y_train_predict = clf.predict(X_train) 
y_train_predictProbas = clf.predict_proba(X_train) 
y_train_predictProbas = y_train_predictProbas[:, 1] 

y_test_predict = clf.predict(X_test) 
y_test_predictProbas = clf.predict_proba(X_test) 
y_test_predictProbas = y_test_predictProbas[:, 1] 

# Get the thetas (coefficients) and the intercept from the model
thetas = clf.coef_[0]
intercept = clf.intercept_[0]
print('thetas=')
print(thetas)
print('intercept=')
print(intercept)

# Display the predictors and their associated thetas
# (labels is defined elsewhere; the idx+1 offset suggests its first entry is not a predictor)
for idx, x in enumerate(thetas):
    print("Predictor: " + str(labels[idx+1]) + "=" + str(x))

# Prepend the intercept to the thetas (scikit-learn stores theta0 separately, in intercept_)
interceptAndThetas = np.append([intercept], thetas)
# Prepend a column of ones to X_test so the intercept takes part in the dot product
X_testWithThetaZero = np.column_stack([np.ones(len(X_test)), np.asarray(X_test)])

# Calculate the dot product (the linear predictor) for plotting the sigmoid
dotProductResult = X_testWithThetaZero.dot(interceptAndThetas)


fig, ax1 = plt.subplots() 

wrongDotProducts = [] 
rightDotProducts = [] 
#Build the plot 
for idx in range(len(dotProductResult)):
    z = dotProductResult[idx]
    proba = y_test_predictProbas[idx]
    actual = y_test[idx]
    color = 'green' if actual == 1 else 'black'

    # plot the predicted value on the sigmoid curve
    ax1.scatter(z, proba, c=[color], linewidths=0.0)
    # plot the actual class (at y=0 or y=1)
    ax1.scatter(z, actual, c=[color], linewidths=0.0)

    # determine which ones are "wrong" so we can make a histogram
    # (a probability of exactly 0.5 is counted as "right" for either class)
    if actual == 1:
        isWrong = proba < 0.5
    else:
        isWrong = proba > 0.5
    if isWrong:
        wrongDotProducts.append(z)
    else:
        rightDotProducts.append(z)

#plt.xlim([-0.05, numInstances + 0.05]) 
plt.ylim([-0.05, 1.05]) 
plt.xlabel('x') 
plt.grid(which="major", axis='both')
plt.ylabel('Prediction Probability') 
plt.title('Sigmoid Curve & Histogram of Predictions') 


## Add a histogram to show the distributions of the correct/incorrect predictions
ax2 = ax1.twinx() 
ax2.hist(wrongDotProducts, bins=[-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7], hatch="/", color="black", alpha=0.2) 
ax2.hist(rightDotProducts, bins=[-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7], hatch="\\", color="green", alpha=0.2) 

ax2.set_ylabel('Histogram Count of Actual Class\n1=Green 0=Black') 
ax2.set_xlabel('') 
ax2.set_title('') 
plt.show()  
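As a more compact variant of the same idea, the sketch below lets scikit-learn compute the linear predictor with `decision_function()` instead of building the dot product by hand. It assumes `clf` is a binary `LogisticRegression` with default settings. The synthetic data from `make_classification` and the helper name `plot_sigmoid` are stand-ins introduced for illustration; they are not part of the original code.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic stand-in data (an assumption; substitute your own X_test / y_test).
X, y = make_classification(n_samples=200, n_features=4, n_informative=2,
                           n_redundant=0, random_state=0)
clf = LogisticRegression().fit(X, y)

# Linear predictor z = intercept + X . thetas, computed by scikit-learn.
z = clf.decision_function(X)
# Probability of class 1 (column order follows clf.classes_).
proba = clf.predict_proba(X)[:, 1]


def plot_sigmoid(z, proba, y):
    """Scatter predicted probabilities and actual classes against z."""
    fig, ax = plt.subplots()
    colors = ['green' if label == 1 else 'black' for label in y]
    ax.scatter(z, proba, c=colors, linewidths=0.0)  # points on the sigmoid
    ax.scatter(z, y, c=colors, linewidths=0.0)      # actual classes at 0 and 1
    ax.set_xlabel('z (linear predictor)')
    ax.set_ylabel('Prediction Probability')
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True)
    return fig, ax
```

To draw the figure, call `plot_sigmoid(z, proba, y)` followed by `plt.show()`. The histogram of right and wrong predictions from the code above can be added on a twin axis in the same way.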
