
Spark: overriding a library method

I want to make some modifications to the Scala code of spark.ml.classification.LogisticRegression without having to rebuild all of Spark. Since we can attach jar files to a spark-submit or pySpark run, is it possible to compile a modified copy of LogisticRegression.scala and have it override Spark's default methods, or at least add new methods? Thanks.

Answer


Creating a new class that extends org.apache.spark.ml.classification.LogisticRegression and overrides the relevant methods should work, without modifying the source code.

import org.apache.spark.ml.classification.LogisticRegression

class CustomLogisticRegression extends LogisticRegression {
  override def toString(): String = "This is overridden Logistic Regression Class"
}

Run logistic regression with the new CustomLogisticRegression:

import org.apache.spark.mllib.util.MLUtils

// Load the sample LIBSVM data shipped with Spark and convert it to a DataFrame.
val data = sqlCtx.createDataFrame(MLUtils.loadLibSVMFile(sc, "/opt/spark/spark-1.5.2-bin-hadoop2.6/data/mllib/sample_libsvm_data.txt"))

val customLR = new CustomLogisticRegression() 
    .setMaxIter(10) 
    .setRegParam(0.3) 
    .setElasticNetParam(0.8) 

val customLRModel = customLR.fit(data) 

val originalLR = new LogisticRegression() 
    .setMaxIter(10) 
    .setRegParam(0.3) 
    .setElasticNetParam(0.8) 

val originalLRModel = originalLR.fit(data) 

// Print the intercept for logistic regression 
println(s"Custom Class's Intercept: ${customLRModel.intercept}") 
println(s"Original Class's Intercept: ${originalLRModel.intercept}") 
println(customLR.toString()) 
println(originalLR.toString()) 

Output:

Custom Class's Intercept: 0.22456315961250317 
Original Class's Intercept: 0.22456315961250317 
This is overridden Logistic Regression Class 
logreg_1cd811a145d7 
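
Since the question also asks about adding new methods, the same subclassing approach covers that as well. Below is a minimal sketch under that assumption; the class name ExtendedLogisticRegression and the fitAndDescribe helper are hypothetical and not part of Spark's API:

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.sql.DataFrame

class ExtendedLogisticRegression extends LogisticRegression {
  // Hypothetical new method: fit the model and print its intercept before returning it.
  def fitAndDescribe(data: DataFrame): LogisticRegressionModel = {
    val model = fit(data)
    println(s"Fitted $uid with intercept ${model.intercept}")
    model
  }
}

// Usage, reusing the DataFrame loaded above:
val extLR = new ExtendedLogisticRegression().setMaxIter(10)
val extModel = extLR.fitAndDescribe(data)

Compile the subclass into a jar and attach it with spark-submit's --jars option (or put it on the application classpath); Spark itself does not need to be rebuilt.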

Thank you so much, this was really useful! –