2015-03-24 28 views
1

我有以下兩種使用Apache Spark中的對象的方法。減少兩個Scala方法,只有一個不同對象類型

def SVMModelScoring(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] = { 
    val model = SVMModel.load(sc, modelFileName) 

    val scoreAndLabels = 
     MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point => 
     val score = model.predict(point.features) 
     (score, point.label) 
     } 
    return scoreAndLabels 
    } 

    def DecisionTreeScoring(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] = { 
    val model = DecisionTreeModel.load(sc, modelFileName) 

    val scoreAndLabels = 
     MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point => 
     val score = model.predict(point.features) 
     (score, point.label) 
     } 
    return scoreAndLabels 
    } 

我以前的嘗試合併這些函數導致錯誤環繞model.predict。

有沒有一種方法可以使用模型作爲在Scala中弱類型的參數?

回答

2

免責聲明 - 我從來沒有使用Apache Spark。

它看起來像我這兩種方法之間的唯一區別是model實例化的方式。這是不幸的是,這兩個model情況下,實際上並不共享一個共同的特點,提供predict(...)但我們仍然可以通過拉出改變部分這項工作 - 在scorer

def scoreWith(sc: SparkContext, scoringDataset: String)(scorer: (Vector)=>Double): RDD[(Double, Double)] = { 
    MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point => 
    val score = scorer(point.features) 
    (score, point.label) 
    } 
} 

現在,我們可以得到以前的功能與:

def svmScorer(sc: SparkContext, scoringDataset:String, modelFileName:String) = 
    scoreWith(sc: SparkContext, scoringDataset:String)(SVMModel.load(sc, modelFileName).predict) 

def dtScorer(sc: SparkContext, scoringDataset:String, modelFileName:String) = 
    scoreWith(sc: SparkContext, scoringDataset:String)(DecisionTreeModel.load(sc, modelFileName).predict) 
+0

非常有希望...但我的IDE(Eclipse)拋出一個錯誤的'矢量'在你提供的代碼的第一行(我試圖調試,但萬一你可能知道.. ) – 2015-03-25 01:09:14

+0

「type Vector需要類型參數」 – 2015-03-25 01:09:48

+0

您的IDE已導入Scala的'Vector [T]'collec而不是Spark的:[org.apache.spark.mllib.linalg.Vector](http://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.mllib。 linalg.Vector) – millhouse 2015-03-25 01:44:35

相關問題