2017-01-17 80 views
1

火花(2.1.0),我用了一個CrossValidator來訓練RandomForestRegressor得到MAXDEPTH,使用ParamGridBuildermaxDepthnumTrees如何從星火RandomForestRegressionModel

paramGrid = ParamGridBuilder() \ 
    .addGrid(rf.maxDepth, [2, 4, 6, 8, 10]) \ 
    .addGrid(rf.numTrees, [10, 20, 40, 50]) \ 
    .build() 

訓練結束後,我能得到樹的最佳數量:

regressor = cvModel.bestModel.stages[len(cvModel.bestModel.stages) - 1] 

print(regressor.getNumTrees) 

但我不能解決如何獲得最佳maxDepth。我讀過documentation,我看不到我在想什麼。

我注意到,我可以通過所有的樹遍歷,找到每個人的深度,如

regressor.trees[0].depth 

這好像我失去了一些東西,但。

回答

2

不幸的是,在Spark 2.3之前的PySpark RandomForestRegressionModel,與Scala對應的不同,它不存儲上游EstimatorParams,但是您應該能夠直接從JVM對象中檢索它。用一個簡單的猴子補丁:

from pyspark.ml.regression import RandomForestRegressionModel 

RandomForestRegressionModel.getMaxDepth = (
    lambda self: self._java_obj.getMaxDepth() 
) 

,您可以:

cvModel.bestModel.stages[-1].getMaxDepth() 
0

就更簡單了,只需撥打

cvModel.bestModel.stages[-1]._java_obj.getMaxDepth() 

正如@ user6910411解釋,你得到的bestModel,調用的JVM對象此模型並使用JVM對象中的getMaxDepth()提取參數。 其他參數類似。