2016-12-15 57 views
0

工作在Kaggle泰坦尼克號數據集。我試圖更好地理解決策樹,我已經很好地使用了線性迴歸,但從來沒有決策樹。我正在嘗試爲我的樹在Python中創建一個可視化文件。有些東西雖然不起作用。在下面檢查我的代碼。決策樹與SKlearn和可視化

import pandas as pd 
from sklearn import tree 
from sklearn.datasets import load_iris 
import numpy as np 


train_file='.......\RUN.csv' 
train=pd.read_csv(train_file) 

#impute number values and missing values 
train["Sex"][train["Sex"] == "male"] = 0 
train["Sex"][train["Sex"] == "female"] = 1 
train["Embarked"] = train["Embarked"].fillna("S") 
train["Embarked"][train["Embarked"] == "S"]= 0 
train["Embarked"][train["Embarked"] == "C"]= 1 
train["Embarked"][train["Embarked"] == "Q"]= 2 
train["Age"] = train["Age"].fillna(train["Age"].median()) 
train["Pclass"] = train["Pclass"].fillna(train["Pclass"].median()) 
train["Fare"] = train["Fare"].fillna(train["Fare"].median()) 

target = train["Survived"].values 
features_one = train[["Pclass", "Sex", "Age", "Fare","SibSp","Parch","Embarked"]].values 


# Fit your first decision tree: my_tree_one 
my_tree_one = tree.DecisionTreeClassifier(max_depth = 10, min_samples_split = 5, random_state = 1) 

iris=load_iris() 

my_tree_one = my_tree_one.fit(features_one, target) 

tree.export_graphviz(my_tree_one, out_file='tree.dot') 

我該如何看到決策樹?嘗試將其可視化。

幫助感謝!

回答

2

你檢查:http://scikit-learn.org/stable/modules/tree.html提到如何繪製樹PNG圖像:

from IPython.display import Image 
import pydotplus 
dot_data = tree.export_graphviz(my_tree_one, out_file='tree.dot') 
graph = pydotplus.graph_from_dot_data(dot_data) ` 
Image(graph.create_png()) 
+0

>>> import os >>> os.unlink('iris.dot') –

+0

I t說這樣做^。但是,只是刪除該文件。有任何想法嗎?我也沒有pydotplus。我試着用pip下載它,但沒有奏效。 –

+0

我認爲問題是Graphiz,你應該下載它:http://www.graphviz.org/Download..php http://stackoverflow.com/questions/18438997/why-is-pydot-unable-to-find -graphvizs-可執行文件,在窗口-8。首先安裝graphiz然後pydot。或者使用linux。稍後我會回到它。 – Roxanne

0

維基百科:

的DOT語言定義的圖形,但不提供用於呈現設施圖形。有跡象表明,可以用來渲染,查看和操作的DOT語言圖形幾個方案:

的Graphviz - 庫和工具的集合,操作和渲染圖

Canviz - 一個JavaScript庫,用於渲染點文件。

Viz.js - 一個簡單的Graphviz JavaScript客戶

拉帕 - 的Graphviz的局部端口到Java [4] [5]

Beluging - Python & Google雲基於DOT和Beluga擴展的查看器。 [1]

鬱金香可以導入點文件進行分析

的OmniGraffle可以導入DOT的子集,產生一個可編輯的文檔。 (結果卻無法回輸到DOT。)

ZGRViewer,一個GraphViz的/ DOT查看器鏈接

VizierFX中,縮放圖形渲染庫鏈接

Gephi - 交互式可視化和勘探平臺各種網絡和複雜系統,動態和分層圖形

因此,這些程序中的任何一個都能夠可視化你的樹。

+0

我已經使用graphviz,但無法將其顯示爲圖像。它只是將它寫入.dot文件。我已經嘗試將ti更改爲pdf,但似乎無法使其工作。 –

+0

我相信這應該只是寫入.dot文件。然後您必須使用列出的應用程序之一來查看.dot文件。我個人喜歡格西。 –

0

我用條形圖做了一個可視化。第一個圖表示類的分佈。第一個標題代表第一個分裂標準。所有滿足這個標準的數據都會導致左下方的子圖。如果不是,則右圖是結果。因此,所有標題都表示下一次拆分的拆分標準。

百分比是來自初始分佈的值。因此,通過查看百分比,可以容易地從初始數量的數據中獲得多少分割後剩下的數據。

注意,如果你設置MAX_DEPTH高,這將需要大量的次要情節的(MAX_DEPTH,2 ^深度)

Tree visualization using bar plots

代碼:

def give_nodes(nodes,amount_of_branches,left,right): 
    amount_of_branches*=2 
    nodes_splits=[] 
    for node in nodes: 
     nodes_splits.append(left[node]) 
     nodes_splits.append(right[node]) 
    return (nodes_splits,amount_of_branches) 

def plot_tree(tree, feature_names): 
    from matplotlib import gridspec 
    import matplotlib.pyplot as plt 
    from matplotlib import rc 
    import pylab 

    color = plt.cm.coolwarm(np.linspace(1,0,len(feature_names))) 

    plt.rc('text', usetex=True) 
    plt.rc('font', family='sans-serif') 
    plt.rc('font', size=14) 

    params = {'legend.fontsize': 20, 
      'axes.labelsize': 20, 
      'axes.titlesize':25, 
      'xtick.labelsize':20, 
      'ytick.labelsize':20} 
    plt.rcParams.update(params) 

    max_depth=tree.max_depth 
    left  = tree.tree_.children_left 
    right  = tree.tree_.children_right 
    threshold = tree.tree_.threshold 
    features = [feature_names[i] for i in tree.tree_.feature] 
    value = tree.tree_.value 

    fig = plt.figure(figsize=(3*2**max_depth,2*2**max_depth)) 
    gs = gridspec.GridSpec(max_depth, 2**max_depth) 
    plt.subplots_adjust(hspace = 0.6, wspace=0.8) 

    # All data 
    amount_of_branches=1 
    nodes=[0] 
    normalize=np.sum(value[0][0]) 

    for i,node in enumerate(nodes): 
     ax=fig.add_subplot(gs[0,(2**max_depth*i)/amount_of_branches:(2**max_depth*(i+1))/amount_of_branches]) 
     ax.set_title(features[node]+"$<= "+str(threshold[node])+"$") 
     if(i==0): ax.set_ylabel(r'$\%$') 
     ind=np.arange(1,len(value[node][0])+1,1) 
     width=0.2 
     bars= (np.array(value[node][0])/normalize)*100 
     plt.bar(ind-width/2, bars, width,color=color,alpha=1,linewidth=0) 
     plt.xticks(ind, [int(i) for i in ind-1]) 
     pylab.ticklabel_format(axis='y',style='sci',scilimits=(0,2)) 

    # Splits 
    for j in range(1,max_depth): 
     nodes,amount_of_branches=give_nodes(nodes,amount_of_branches,left,right) 
     for i,node in enumerate(nodes): 
      ax=fig.add_subplot(gs[j,(2**max_depth*i)/amount_of_branches:(2**max_depth*(i+1))/amount_of_branches]) 
      ax.set_title(features[node]+"$<= "+str(threshold[node])+"$") 
      if(i==0): ax.set_ylabel(r'$\%$') 
      ind=np.arange(1,len(value[node][0])+1,1) 
      width=0.2 
      bars= (np.array(value[node][0])/normalize)*100 
      plt.bar(ind-width/2, bars, width,color=color,alpha=1,linewidth=0) 
      plt.xticks(ind, [int(i) for i in ind-1]) 
      pylab.ticklabel_format(axis='y',style='sci',scilimits=(0,2)) 


    plt.tight_layout() 
    return fig 

例子:

X=[] 
Y=[] 
amount_of_labels=5 
feature_names=[ '$x_1$','$x_2$','$x_3$','$x_4$','$x_5$'] 
for i in range(200): 
    X.append([np.random.normal(),np.random.randint(0,100),np.random.uniform(200,500) ]) 
    Y.append(np.random.randint(0,amount_of_labels)) 

clf = tree.DecisionTreeClassifier(criterion='entropy',max_depth=4) 
clf = clf.fit(X,Y) 
fig=plot_tree(clf, feature_names)