2016-12-07 43 views
3

我在使用OneVsRest分類器對三個分類問題(三個隨機森林)進行分類。每個類的出現都被定義爲我的虛擬整數(出現時爲1,否則爲0)。我想知道是否有一種簡單的替代方法來創建混淆矩陣?正如我遇到的所有方法一樣,以y_pred,y_train = array,shape = [n_samples]的形式獲取參數。理想情況下,我想y_pred,y_train =陣列,形狀= [N_SAMPLES次,n_classes]有沒有一種簡單的方法來獲得多類分類的混淆矩陣? (OneVsRest)

一些示例中,類似於問題的結構:

y_train = np.array([(1,0,0), (1,0,0), (0,0,1), (1,0,0), (0,1,0)]) 
y_pred = np.array([(1,0,0), (0,1,0), (0,0,1), (0,1,0), (1,0,0)]) 


print(metrics.confusion_matrix(y_train, y_pred) 

RETURNS: 多標記指示器不支持

+0

這看起來像多類,而不是多標記。 –

回答

6

我不知道你的想法,因爲你沒有指定你要找的輸出,但這裏有兩種方法,你可以去了解它:每列

1.One混淆矩陣

In [1]: 
for i in range(y_train.shape[1]): 
    print("Col {}".format(i)) 
    print(metrics.confusion_matrix(y_train[:,i], y_pred[:,i])) 
    print("") 

Out[1]: 
Col 0 
[[1 1] 
[2 1]] 

Col 1 
[[2 2] 
[1 0]] 

Col 2 
[[4 0] 
[0 1]] 

2.One混淆矩陣共

對於這一點,我們要扁平化陣列:

In [2]: print(metrics.confusion_matrix(y_train.flatten(), y_pred.flatten())) 

Out[2]: 
[[7 3] 
[3 2]] 
相關問題