2016-08-13

I am training a model from a CSV file and saving it. That all works fine. After saving, I try to load the saved model and apply it to new data, but it does not work: reloading the Spark model appears to produce no output at all.

What is the problem?

The training Java file:

SparkConf sconf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("Test")
    .set("spark.sql.warehouse.dir", "D:/Temp/wh");
SparkSession spark = SparkSession.builder().appName("Java Spark").config(sconf).getOrCreate();

// Parse the CSV: column 0 is the product name, column 1 is the class name
JavaRDD<Cobj> cRDD = spark.read().textFile("file:///C:/Temp/classifications1.csv").javaRDD()
    .map(new Function<String, Cobj>() {
        @Override
        public Cobj call(String line) throws Exception {
            String[] parts = line.split(",");
            Cobj c = new Cobj();
            c.setClassName(parts[1].trim());
            c.setProductName(parts[0].trim());
            return c;
        }
    });

Dataset<Row> mainDataset = spark.createDataFrame(cRDD, Cobj.class);

// StringIndexer: map class names to numeric labels
StringIndexer classIndexer = new StringIndexer()
    .setHandleInvalid("skip")
    .setInputCol("className")
    .setOutputCol("label");
StringIndexerModel classIndexerModel = classIndexer.fit(mainDataset);

// Tokenizer: split product names into words
Tokenizer tokenizer = new Tokenizer()
    .setInputCol("productName")
    .setOutputCol("words");

// HashingTF: hash the words into term-frequency feature vectors
HashingTF hashingTF = new HashingTF()
    .setInputCol(tokenizer.getOutputCol())
    .setOutputCol("features");

DecisionTreeClassifier decisionClassifier = new DecisionTreeClassifier()
    .setLabelCol("label")
    .setFeaturesCol("features");

Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[] {classIndexer, tokenizer, hashingTF, decisionClassifier});

Dataset<Row>[] splits = mainDataset.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> train = splits[0];
Dataset<Row> test = splits[1];

PipelineModel pipelineModel = pipeline.fit(train);

Dataset<Row> result = pipelineModel.transform(test);
pipelineModel.write().overwrite().save(savePath + "DecisionTreeClassificationModel");

// Convert the predicted indices back to class names
IndexToString labelConverter = new IndexToString()
    .setInputCol("prediction")
    .setOutputCol("PredictedClassName")
    .setLabels(classIndexerModel.labels());
result = labelConverter.transform(result);
result.show(num, false);

Dataset<Row> predictionAndLabels = result.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
    .setMetricName("accuracy");
System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));

Output:

+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
|className                 |productName                                  |label|words                                                 |features                                                                                         |rawPrediction        |probability          |prediction|PredictedClassName        |
+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
|Apple iPhone 6S 16GB      |Apple IPHONE 6S 16GB SGAY Telefon            |2.0  |[apple, iphone, 6s, 16gb, sgay, telefon]              |(262144,[27536,56559,169565,200223,210029,242621],[1.0,1.0,1.0,1.0,1.0,1.0])                     |[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0       |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S 16GB      |Apple iPhone 6S 16 GB Space Gray MKQJ2TU/A   |2.0  |[apple, iphone, 6s, 16, gb, space, gray, mkqj2tu/a]   |(262144,[10879,56559,95900,139131,175329,175778,200223,210029],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0       |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S 16GB      |iPhone 6s 16GB                               |2.0  |[iphone, 6s, 16gb]                                    |(262144,[27536,56559,210029],[1.0,1.0,1.0])                                                      |[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0       |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S Plus 128GB|Apple IPHONE 6S PLUS 128GB SG Telefon        |4.0  |[apple, iphone, 6s, plus, 128gb, sg, telefon]         |(262144,[56559,99916,137263,175839,200223,210029,242621],[1.0,1.0,1.0,1.0,1.0,1.0,1.0])          |[0.0,0.0,0.0,0.0,2.0]|[0.0,0.0,0.0,0.0,1.0]|4.0       |Apple iPhone 6S Plus 128GB|
|Apple iPhone 6S Plus 16GB |Iphone 6S Plus 16GB SpaceGray - Apple Türkiye|1.0  |[iphone, 6s, plus, 16gb, spacegray, -, apple, türkiye]|(262144,[27536,45531,46750,56559,59104,99916,200223,210029],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])   |[0.0,5.0,0.0,0.0,0.0]|[0.0,1.0,0.0,0.0,0.0]|1.0       |Apple iPhone 6S Plus 16GB |
+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
Accuracy = 1.0

The loading Java file:

SparkConf sconf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("Test")
    .set("spark.sql.warehouse.dir", "D:/Temp/wh");
SparkSession spark = SparkSession.builder().appName("Java Spark").config(sconf).getOrCreate();

// The new data has no class names, so a placeholder "?" is used
JavaRDD<Cobj> cRDD = spark.read().textFile("file:///C:/Temp/classificationsTest.csv").javaRDD()
    .map(new Function<String, Cobj>() {
        @Override
        public Cobj call(String line) throws Exception {
            String[] parts = line.split(",");
            Cobj c = new Cobj();
            c.setClassName("?");
            c.setProductName(parts[0].trim());
            return c;
        }
    });

Dataset<Row> mainDataset = spark.createDataFrame(cRDD, Cobj.class);
mainDataset.show(100, false);

PipelineModel pipelineModel = PipelineModel.load(savePath + "DecisionTreeClassificationModel");

Dataset<Row> result = pipelineModel.transform(mainDataset);
result.show(100, false);

Output:

+---------+-----------+-----+-----+--------+-------------+-----------+----------+
|className|productName|label|words|features|rawPrediction|probability|prediction|
+---------+-----------+-----+-----+--------+-------------+-----------+----------+
+---------+-----------+-----+-----+--------+-------------+-----------+----------+

Answer

I removed the StringIndexer from the pipeline and saved it separately as "StringIndexer". The saved pipeline contained the StringIndexer fitted with setHandleInvalid("skip"); since every className in the new data is the placeholder "?", a label never seen during fitting, the indexer skipped every row and the transform returned an empty dataset. In the second file, after loading the pipeline, I load the StringIndexer and use it to convert the predicted indices back to label strings.
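A minimal, self-contained sketch of that fix, assuming the same column names as above; the tiny in-memory dataset and the /tmp/sketch/ save path are illustrative, not from the question:

```java
import java.util.Arrays;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SeparateIndexerSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .master("local[*]").appName("Sketch").getOrCreate();
        String savePath = "/tmp/sketch/";  // illustrative path

        StructType schema = new StructType()
            .add("productName", DataTypes.StringType)
            .add("className", DataTypes.StringType);
        Dataset<Row> train = spark.createDataFrame(Arrays.asList(
            RowFactory.create("Apple iPhone 6S 16GB", "Apple iPhone 6S 16GB"),
            RowFactory.create("iPhone 6s Plus 64 GB", "Apple iPhone 6S Plus 64GB")), schema);

        // 1) Fit and save the StringIndexer on its own, outside the pipeline.
        StringIndexerModel indexerModel = new StringIndexer()
            .setInputCol("className").setOutputCol("label").fit(train);
        indexerModel.write().overwrite().save(savePath + "StringIndexer");

        // 2) The pipeline keeps only the stages that read "productName",
        //    so transforming unlabeled data cannot skip any rows.
        Tokenizer tokenizer = new Tokenizer().setInputCol("productName").setOutputCol("words");
        HashingTF hashingTF = new HashingTF().setInputCol("words").setOutputCol("features");
        DecisionTreeClassifier dt = new DecisionTreeClassifier()
            .setLabelCol("label").setFeaturesCol("features");
        Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{tokenizer, hashingTF, dt});
        PipelineModel model = pipeline.fit(indexerModel.transform(train));
        model.write().overwrite().save(savePath + "DecisionTreeClassificationModel");

        // 3) Prediction side: reload both artifacts; the indexer only
        //    supplies the label strings for IndexToString.
        PipelineModel loaded = PipelineModel.load(savePath + "DecisionTreeClassificationModel");
        StringIndexerModel loadedIndexer = StringIndexerModel.load(savePath + "StringIndexer");
        Dataset<Row> newData = spark.createDataFrame(Arrays.asList(
            RowFactory.create("iPhone 6s 16GB", "?")), schema);
        Dataset<Row> predicted = new IndexToString()
            .setInputCol("prediction").setOutputCol("PredictedClassName")
            .setLabels(loadedIndexer.labels())
            .transform(loaded.transform(newData));
        predicted.show(false);
        spark.stop();
    }
}
```

With the indexer outside the pipeline, the reloaded model transforms the unlabeled rows instead of filtering them all out.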