2017-07-09 36 views
0

我在Apache Spark ML(版本2.1.0)中使用NaiveBayes多分類分類器來預測某些文本類別。Spark ML將預測標籤轉換爲字符串,無需培訓DataFrame

問題是如何將預測標籤(0.0,1.0,2.0)轉換爲沒有經過培訓的DataFrame的字符串。

我知道IndexToString可以使用,但它的唯一幫助,如果訓練和預測都在同一時間。但是,在我的情況下,它的獨立工作。

代碼看起來像
1)TrainingModel.scala:訓練模型並將模型保存在文件中。
2)CategoryPrediction.scala:從文件中加載訓練好的模型,並對測試數據進行預測。

請建議解決方案:

TrainingModel.scala

val trainData: Dataset[LabeledRecord] = spark.read.option("inferSchema", "false") 
    .schema(schema).csv("trainingdata1.csv").as[LabeledRecord] 

val labelIndexer = new StringIndexer().setInputCol("category").setOutputCol("label").fit(trainData).setHandleInvalid("skip") 

val tokenizer = new RegexTokenizer().setInputCol("text").setOutputCol("words") 

val hashingTF = new HashingTF() 
    .setInputCol("words") 
    .setOutputCol("features") 
    .setNumFeatures(1000) 

val rf = new NaiveBayes().setLabelCol("label").setFeaturesCol("features").setModelType("multinomial") 

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, labelIndexer, rf)) 

val model = pipeline.fit(trainData) 

model.write.overwrite().save("naivebayesmodel"); 

CategoryPrediction.scala

val testData: Dataset[PredictLabeledRecord] = spark.read.option("inferSchema", "false") 
     .schema(predictSchema).csv("testingdata.csv").as[PredictLabeledRecord] 

val model = PipelineModel.load("naivebayesmodel") 

val predictions = model.transform(testData) 

//  val labelConverter = new IndexToString() 
//  .setInputCol("prediction") 
//  .setOutputCol("predictedLabelString") 
//  .setLabels(trainDataFrameIndexer.labels)  

predictions.select("prediction", "text").show(false) 

trainingdata1.csv

category,text 
Drama,"a b c d e spark" 
Action,"b d" 
Horror,"spark f g h" 
Thriller,"hadoop mapreduce" 

testingdata.csv

text 
"a b c d e spark" 
"spark f g h" 

回答

-1

添加一個轉換器,將在您的管道預測類別轉換回你的標籤,像這樣:

val categoryConverter = new IndexToString() 
    .setInputCol("prediction") 
    .setOutputCol("category") 
    .setLabels(labelIndexer.labels) 

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, labelIndexer, rf, categoryConverter)) 

這將需要預測並使用您的labelIndexer將其轉換回標籤。

+0

感謝帕斯卡你的答覆工作正常。 – user657816