2017-08-26 34 views
2

我嘗試使用參數在字符串中的DecisionTreeClassifier。爲DecisionTreeClassifier傳遞參數時出錯

print d # d= 'max_depth=100' 
clf = DecisionTreeClassifier(d) 
clf.fit(X[:3000,], labels[:3000]) 

我在這種情況下得到低於錯誤。如果我使用clf = DecisionTreeClassifier(max_depth=100)它工作正常。

Traceback (most recent call last): 
    File "train.py", line 120, in <module> 
    grid_search_generalized(X, labels, {"max_depth":[i for i in range(100, 200)]}) 
    File "train.py", line 51, in grid_search_generalized 
    clf.fit(X[:3000,], labels[:3000]) 
    File "/usr/local/lib/python2.7/dist-packages/sklearn/tree/tree.py", line 790, in fit 
    X_idx_sorted=X_idx_sorted) 
    File "/usr/local/lib/python2.7/dist-packages/sklearn/tree/tree.py", line 326, in fit 
    criterion = CRITERIA_CLF[self.criterion](self.n_outputs_, 
KeyError: 'max_depth=100' 

回答

1

您將參數作爲字符串對象傳遞,而不是作爲可選參數傳遞。
如果你真的調用這個字符串的構造函數,你可以使用此代碼:

arg = dict([d.split("=")]) 
clf = DecisionTreeClassifier(**arg) 

你可以閱讀更多有關參數在這個環節
Passing a dictionary to a function in python as keyword parameters

+0

'回溯(最近通話拆包最後): 文件 「train.py」,線121,在 grid_search_generalized(X,標籤,{ 「MAX_DEPTH」:[100]}) 文件 「train.py」,線路52,在grid_search_generalized clf.fit (X [:3000,],標籤[:3000]) 文件「/usr/local/lib/python2.7/dist-packages/sklearn/tree/tree.py」,第790行,擬合 X_idx_sorted = X_idx_sorted) 文件「/usr/local/lib/python2.7 /dist-packages/sklearn/tree/tree.py「,第326行,在fit criterion = CRITERIA_CLF [self.criterion](self.n_outputs_, TypeError:unhashable type:'dict''現在得到這個錯誤。 – tarun14110