2014-05-09 36 views
14

我有兩個問題,瞭解從scikit學習決策樹的結果。例如,這是我的決策樹之一:如何解釋從決策樹scikit學習

enter image description here 我的問題是,我該如何使用樹?

第一個問題是:若樣品滿足的條件,然後將其轉到LEFT分支(如果存在的話),否則它會RIGHT。在我的情況下,如果一個樣本的X [7]> 63521.3984。然後樣品將進入綠色框。正確?

第二個問題是:當一個樣品到達葉節點,我怎樣才能知道它所屬的類別?在這個例子中,我有三個類別進行分類。在紅色框中,分別有91,212和113個樣本滿足條件。但是,我怎樣才能決定這個類別呢? 我知道有一個函數clf.predict(樣品)告訴類別。我可以做圖嗎? 非常感謝。

+1

出於好奇,你是如何繪製決策樹的? – Matt

+4

首先將樹導出爲JSON格式(參見[鏈接](http://www.garysieling.com/blog/rending-scikit-decision-trees-d3-js)),然後使用d3.js繪製該樹。或者你可以直接使用嵌入式函數:'tree.export_graphviz(clf,out_file = your_out_file,feature_names = your_feature_names)'希望它能起作用,@Matt –

回答

21

value線在每個盒子告訴你很多樣品在該節點落入每個類別,爲了如何。這就是爲什麼在每個框中,value中的數字合計爲sample中顯示的數字。例如,在你的紅色框中,91 + 212 + 113 = 416。因此,這意味着如果到達此節點,則類別1中有91個數據點,類別2中有212個數據點,類別3中有113個。

如果您要預測到達該葉節點的新數據點的結果在決策樹中,您會預測類別2,因爲這是該節點上樣本的最常見類別。

+0

我有興趣知道哪個值屬於哪個類。 'DecisionTreeClassifier.classes'持有這個信息。 – ezdazuzena

+0

(有用的答案:爲了澄清使用python索引,儘管:紅色框中的樣本登陸將被預測(計數212)爲類別1,而不是類別0(91)或類別2(113):-)) –

0

根據這本書「學習scikit學習:機器在Python學習」,決策樹表示一系列的基於訓練數據進行決策。

!(http://i.imgur.com/vM9fJLy.png

爲實例進行分類,我們應該回答每個節點的問題。例如,性別< = 0.5? (我們是在談論一個女人?)。 如果答案是肯定的,則轉到樹中的左側子節點;否則你去右邊的子節點。你一直在回答問題(她是在第三堂課嗎?她是在第一堂課嗎?她是13歲以下的?),直到你到達一片葉子。 當您在那裏時,預測對應於具有大多數實例的目標類

2

第一個問題: 是的,你的邏輯是正確的。左邊的節點是True,右邊的節點是False。這是違反直覺的;真實通常意味着一個較小的值。

第二個問題: 這個問題最好通過用pydotplus將圖形可視化爲圖來解決。 tree.export_graphviz()的'class_names'屬性將爲每個節點的大多數類添加一個類聲明。代碼在iPython中執行。

from sklearn.datasets import load_iris 
from sklearn import tree 
iris = load_iris() 
clf2 = tree.DecisionTreeClassifier() 
clf2 = clf2.fit(iris.data, iris.target) 

with open("iris.dot", 'w') as f: 
    f = tree.export_graphviz(clf, out_file=f) 

import os 
os.unlink('iris.dot') 

import pydotplus 
dot_data = tree.export_graphviz(clf2, out_file=None) 
graph2 = pydotplus.graph_from_dot_data(dot_data) 
graph2.write_pdf("iris.pdf") 

from IPython.display import Image 
dot_data = tree.export_graphviz(clf2, out_file=None, 
        feature_names=iris.feature_names, 
        class_names=iris.target_names, 
        filled=True, rounded=True, # leaves_parallel=True, 
        special_characters=True) 
graph2 = pydotplus.graph_from_dot_data(dot_data) 

## Color of nodes 
nodes = graph2.get_node_list() 

for node in nodes: 
    if node.get_label(): 
     values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]; 
     color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],} 
     values = color[values.index(max(values))]; # print(values) 
     color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color) 
     node.set_fillcolor(color) 
# 

Image(graph2.create_png()) 

enter image description here

作爲用於確定的類別在葉,你的例子不具有葉與單個類,如虹膜數據集一樣。這很常見,可能需要過度擬合模型才能獲得這樣的結果。類的離散分佈是許多交叉驗證模型的最佳結果。

享受代碼!