火花(2.1.0),我用了一個CrossValidator
來訓練RandomForestRegressor
得到MAXDEPTH,使用ParamGridBuilder
爲maxDepth
和numTrees
:如何從星火RandomForestRegressionModel
paramGrid = ParamGridBuilder() \
.addGrid(rf.maxDepth, [2, 4, 6, 8, 10]) \
.addGrid(rf.numTrees, [10, 20, 40, 50]) \
.build()
訓練結束後,我能得到樹的最佳數量:
regressor = cvModel.bestModel.stages[len(cvModel.bestModel.stages) - 1]
print(regressor.getNumTrees)
但我不能解決如何獲得最佳maxDepth。我讀過documentation,我看不到我在想什麼。
我注意到,我可以通過所有的樹遍歷,找到每個人的深度,如
regressor.trees[0].depth
這好像我失去了一些東西,但。