2017-08-13 27 views
0

我提出了一個樹分類命名的模型,並試圖用出口graphviz的功能是這樣的:樹分類到graphviz的錯誤

export_graphviz(decision_tree=model, 
        out_file='NT_model.dot', 
        feature_names=X_train.columns, 
        class_names=model.classes_, 
        leaves_parallel=True, 
        filled=True, 
        rotate=False, 
        rounded=True) 

出於某種原因,我跑曾提出這樣的例外:


TypeError         Traceback (most recent call last) 
<ipython-input-298-40fe56bb0c85> in <module>() 
     6      filled=True, 
     7      rotate=False, 
----> 8      rounded=True) 

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site- 
packages\sklearn\tree\export.py in export_graphviz(decision_tree, out_file, 
max_depth, feature_names, class_names, label, filled, leaves_parallel, 
impurity, node_ids, proportion, rotate, rounded, special_characters) 
    431    recurse(decision_tree, 0, criterion="impurity") 
    432   else: 
--> 433    recurse(decision_tree.tree_, 0, 
criterion=decision_tree.criterion) 
    434 
    435   # If required, draw leaf nodes at same depth as each other 

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site- 
packages\sklearn\tree\export.py in recurse(tree, node_id, criterion, parent, 
depth) 
    319    out_file.write('%d [label=%s' 
    320       % (node_id, 
--> 321        node_to_str(tree, node_id, 
criterion))) 
    322 
    323    if filled: 

C:\Users\yonatanv\AppData\Local\Continuum\Anaconda3\lib\site- 
packages\sklearn\tree\export.py in node_to_str(tree, node_id, criterion) 
    289           np.argmax(value), 
    290           characters[2]) 
--> 291    node_string += class_name 
    292 
    293   # Clean up any trailing newlines 

TypeError: ufunc 'add' did not contain a loop with signature matching types 
dtype('<U90') dtype('<U90') dtype('<U90') 

我對可視化超參數是那些:

print(model) 
DecisionTreeClassifier(class_weight={1.0: 10, 0.0: 1}, criterion='gini', 
     max_depth=7, max_features=None, max_leaf_nodes=None, 
     min_impurity_split=1e-07, min_samples_leaf=50, 
     min_samples_split=2, min_weight_fraction_leaf=0.0, 
     presort=False, random_state=0, splitter='best') 

print(model.classes_) 
[ 0. , 1. ] 

幫助將不勝感激!

+0

確保您使用的是scikit-learn的更新版本。如果仍然面臨這個問題,那麼你需要提供更多的細節來幫助我們。從錯誤的完整堆棧跟蹤開始。然後提供您用來訓練'model'的代碼以及一些樣本數據。 –

+0

我正在使用安裝在anaconda3上的版本 –

+0

爲我的問題添加了更多的處理,以便通知我! –

回答

0

正如您在documentation of export_graphviz中所看到的,參數class_names適用於字符串,不適用於float或int。

class_names:字符串列表,布爾或無,可選的(默認=無)

嘗試將它們傳遞在export_graphviz之前轉換到model.classes_字符串列表。

嘗試class_names=['0', '1']class_names=['0.0', '1.0']致電export_graphviz()

對於一個更通用的解決方案,使用方法:

class_names=[str(x) for x in model.classes_]

但是有沒有要傳遞的浮點值作爲ymodel.fit()具體原因是什麼?因爲這在分類任務中大多不需要。你是否真的有y這樣的標籤,或者你是否在裝配模型之前將字符串標籤轉換爲數字?

+0

這裏的y標籤最初是數字,作爲二進制變量 –

+0

好了。沒關係。 –