2014-04-25 229 views
1

我正在使用confusion matrix來檢查分類器的性能。如何解釋Scikit-learn混淆矩陣

我正在使用Scikit-Learn,我有點困惑。我如何解讀

from sklearn.metrics import confusion_matrix 
>>> y_true = [2, 0, 2, 2, 0, 1] 
>>> y_pred = [0, 0, 2, 2, 0, 2] 
>>> confusion_matrix(y_true, y_pred) 
array([[2, 0, 0], 
     [0, 0, 1], 
     [1, 0, 2]]) 

我該如何判斷這個預測值是好還是不好。

回答

1

判斷分類器好壞的最簡單方法就是使用一些標準錯誤度量(例如Mean squared error)來計算錯誤。我想你的例子是從Scikit的documentation複製的,所以我假設你已經閱讀了定義。

我們在這裏有三類:0,12。在對角線上,混淆矩陣告訴你,一個特定類別被正確預測的頻率。因此,從對角線2 0 2可以說,具有索引0的分類被正確分類了2次,索引1的分類從未被正確預測,並且具有索引2的分類被正確預測了2次。

在對角線下方和上方有數字,告訴您索引等於元素行號的類被分類爲索引等於矩陣列的類。例如,如果您查看第一列,則在對角線下有:0 1(位於矩陣的左下角)。較低的1告訴您,索引爲2(最後一行)的班級曾被錯誤地歸類爲0(第一列)。這對應於您的y_true中有一個標籤爲2的樣本,並被歸類爲0。這發生在第一個樣本上。

如果您從混淆矩陣中總結所有數字,則會得到測試樣本的數量(2 + 2 + 1 + 1 = 6 - 等於y_truey_pred的長度)。如果對行進行求和,您將得到每個標籤的樣本數量:如您所能驗證的那樣,確實在y_pred中有兩個0 s,一個1和三個2 s。

例如,如果您將矩陣元素除以該數字,則可以看出,例如,具有標籤2的類被正確識別,準確度爲〜66%,並且在1/3的情況下它是混淆的(因此名稱),標籤爲0

TL; DR:

雖然單數量的錯誤措施,衡量整體性能,混淆矩陣,你可以決定是否(舉例):

  • 您的分類只是一切

  • 或者它可以很好地處理一些類,有些則不然(這會給你一個提示,看看你的數據的這個特定部分,並觀察這些情況下分類器的行爲)

  • 它做得很好,但經常混淆標籤A和B.例如,對於線性分類器,如果這些類可線性分離,則可能需要檢查。

等等