2016-10-21

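Confusion matrix for 10-fold cross-validation in scikit-learn

I am new to machine learning and scikit-learn. I would like to know how to compute a confusion matrix for 10-fold cross-validation with scikit-learn. How can I obtain y_test and y_pred?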


You first need to use sklearn.pipeline.Pipeline: http://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html. Then import KFold from sklearn.model_selection (sklearn.cross_validation in releases before 0.18). For the confusion matrix, see http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html. Combine all three steps in one function. – PJay
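
The three steps named in the comment above (a Pipeline, a KFold split, and confusion_matrix) could be combined roughly as follows. This is a minimal sketch, not PJay's code: the StandardScaler step, the linear SVC, the iris data, and the helper name kfold_confusion_matrix are all assumptions chosen for illustration.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC


def kfold_confusion_matrix(X, y, n_splits=10, seed=0):
    """Sum the per-fold confusion matrices of a scaler + SVC pipeline.

    Hypothetical helper: the pipeline contents are an assumption.
    """
    labels = np.unique(y)
    total = np.zeros((len(labels), len(labels)), dtype=np.int64)
    cv = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    for train_idx, test_idx in cv.split(X):
        # Build a fresh pipeline per fold so the scaler is fitted
        # on the training rows of that fold only.
        model = Pipeline([
            ("scale", StandardScaler()),
            ("svc", SVC(kernel="linear", C=1)),
        ])
        model.fit(X[train_idx], y[train_idx])
        y_pred = model.predict(X[test_idx])
        # Passing labels keeps the matrix shape fixed even if a
        # fold happens to contain no samples of some class.
        total += confusion_matrix(y[test_idx], y_pred, labels=labels)
    return total


iris = load_iris()
cm = kfold_confusion_matrix(iris.data, iris.target)
print(cm)
```

Because every sample lands in exactly one test fold, the summed matrix counts each of the 150 iris samples once, so each row adds up to the size of its class.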


Hi PJay. Could you help me with some more sample code? –

Answers

import itertools

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, svm
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold


def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        # Divide each row by its total; rows with no samples stay all zero.
        row_sums = np.maximum(cm.sum(axis=1), 1)
        cm = cm.astype('float') / row_sums[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    # Draw after normalizing so the colors match the printed values.
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')


iris = datasets.load_iris()
class_names = iris.target_names
labels = np.arange(len(class_names))
np.set_printoptions(precision=2)

# iris has 150 samples sorted by class, so shuffle before splitting;
# without shuffling each test fold holds at most two of the three classes.
cv = KFold(n_splits=10, shuffle=True, random_state=0)
for train_index, test_index in cv.split(iris.data):
    X_tr, X_tes = iris.data[train_index], iris.data[test_index]
    y_tr, y_tes = iris.target[train_index], iris.target[test_index]
    clf = svm.SVC(kernel='linear', C=1).fit(X_tr, y_tr)

    y_pred = clf.predict(X_tes)
    # Fix the label set so every fold yields a 3x3 matrix.
    cnf_matrix = confusion_matrix(y_tes, y_pred, labels=labels)

    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=class_names,
                          title='Confusion matrix, without normalization')
    # Plot normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                          title='Normalized confusion matrix')

    plt.show()
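
If one confusion matrix over all 10 folds is enough, a shorter route may be cross_val_predict, which returns for every sample the prediction made while that sample was in the held-out fold. The following is a sketch only, assuming the same iris data and linear SVC as the loop above; the StratifiedKFold splitter is an assumption and not part of the answer. Here y_test is simply the full label vector and y_pred holds the out-of-fold predictions.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.svm import SVC

iris = load_iris()
X, y_test = iris.data, iris.target

# Stratified folds keep all three classes present in every test fold.
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

# One prediction per sample, each made by a model that never saw it.
y_pred = cross_val_predict(SVC(kernel="linear", C=1), X, y_test, cv=cv)

cm = confusion_matrix(y_test, y_pred)
print(cm)
```

The resulting matrix can be passed to plot_confusion_matrix once, instead of drawing two figures per fold.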

Hi Chandan. This code does not work. –


Please add this line: %matplotlib inline – Chandan


Could you tell me what the problem is? – Chandan