0

擴大在之前討論的節點:不是二進制區別 Changing colors for decision tree plot created using export graphviz彩色樹使用的graphviz class_names

我怎麼會色樹基地的節點上的統治階級(虹膜種), ?這應該需要將iris.target_names(描述​​該類的字符串)和iris.target(該類)組合在一起。

import pydotplus 
from sklearn.datasets import load_iris 
from sklearn import tree 
import collections 

clf = tree.DecisionTreeClassifier(random_state=42) 
iris = load_iris() 

clf = clf.fit(iris.data, iris.target) 

dot_data = tree.export_graphviz(clf, out_file=None, 
           feature_names=iris.feature_names, 
           class_names=iris.target_names, 
           filled=True, rounded=True, 
           special_characters=True) 
graph = pydotplus.graph_from_dot_data(dot_data) 
nodes = graph.get_node_list() 
edges = graph.get_edge_list() 

colors = ('brown', 'forestgreen') 
edges = collections.defaultdict(list) 

for edge in graph.get_edge_list(): 
    edges[edge.get_source()].append(int(edge.get_destination())) 

for edge in edges: 
    edges[edge].sort()  
    for i in range(2): 
     dest = graph.get_node(str(edges[edge][i]))[0] 
     dest.set_fillcolor(colors[i]) 

graph.write_png('tree.png') 

回答

0

從示例代碼看起來那麼熟悉,因此易於修改:)

對於每個節點Graphviz告訴我們如何從每個組中,我們有,即許多樣品,如果它是一個混合人羣或樹決定。我們可以提取此信息並用於獲取顏色。

values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')] 

或者您可以在GraphViz節點返回映射到sklearn節點:

values = clf.tree_.value[int(node.get_name())][0] 

我們只有3班,讓每個人都有自己的顏色(紅,綠,藍),混血人種根據其分佈獲取混合顏色。

values = [int(255 * v/sum(values)) for v in values] 
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]) 

enter image description here

我們現在可以看到分離很好,它得到第二類,我們有,同爲藍色和第三類的更多的更環保。


import pydotplus 
from sklearn.datasets import load_iris 
from sklearn import tree 

clf = tree.DecisionTreeClassifier(random_state=42) 
iris = load_iris() 

clf = clf.fit(iris.data, iris.target) 

dot_data = tree.export_graphviz(clf, 
           feature_names=iris.feature_names, 
           out_file=None, 
           filled=True, 
           rounded=True, 
           special_characters=True) 
graph = pydotplus.graph_from_dot_data(dot_data) 
nodes = graph.get_node_list() 

for node in nodes: 
    if node.get_label(): 
     values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')] 
     values = [int(255 * v/sum(values)) for v in values] 
     color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]) 
     node.set_fillcolor(color) 

graph.write_png('colored_tree.png') 

超過3類,其顏色只有最終節點的一般解。

colors = ('lightblue', 'lightyellow', 'forestgreen', 'lightred', 'white') 

for node in nodes: 
    if node.get_name() not in ('node', 'edge'): 
     values = clf.tree_.value[int(node.get_name())][0] 
     #color only nodes where only one class is present 
     if max(values) == sum(values):  
      node.set_fillcolor(colors[numpy.argmax(values)]) 
     #mixed nodes get the default color 
     else: 
      node.set_fillcolor(colors[-1]) 

enter image description here

+0

我個人的問題有四類。你會如何概括這個n班? – MyopicVisage

+0

我曾打算使用:colors =('lightblue','lightyellow','lightgreen','lightred'),然後在它們之間進行插值。 – MyopicVisage

+0

@MyopicVisage:這是具有挑戰性的:)我個人更喜歡只有最後的節點着色,否則它會變成聖誕樹。我會多考慮一下。 –