2017-05-13 52 views
2

我在我的java代碼中使用了weka中的LibSVM。我正在嘗試做一個迴歸。下面是我的代碼,Java,weka LibSVM不能正確預測

public static void predict() { 

    try { 
     DataSource sourcePref1 = new DataSource("train_pref2new.arff"); 
     Instances trainData = sourcePref1.getDataSet(); 

     DataSource sourcePref2 = new DataSource("testDatanew.arff"); 
     Instances testData = sourcePref2.getDataSet(); 

     if (trainData.classIndex() == -1) { 
      trainData.setClassIndex(trainData.numAttributes() - 2); 
     } 

     if (testData.classIndex() == -1) { 
      testData.setClassIndex(testData.numAttributes() - 2); 
     } 

     LibSVM svm1 = new LibSVM(); 

     String options = ("-S 3 -K 2 -D 3 -G 1000.0 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.001 -P 0.1"); 
     String[] optionsArray = options.split(" "); 
     svm1.setOptions(optionsArray); 

     svm1.buildClassifier(trainData); 

     for (int i = 0; i < testData.numInstances(); i++) { 

      double pref1 = svm1.classifyInstance(testData.instance(i));     
      System.out.println("predicted value : " + pref1); 

     } 

    } catch (Exception ex) { 
     Logger.getLogger(Test.class.getName()).log(Level.SEVERE, null, ex); 
    } 
} 

但預測值我從這個代碼得到的是比預測值,我使用的Weka GUI越來越不同。

示例: 下面是我爲java代碼和weka GUI提供的單個測試數據。

的Java代碼的預測值作爲1.9064516129032265而Weka的GUI的預測值是10.043。我爲Java代碼和Weka GUI使用相同的訓練數據集和相同的參數。

我希望你明白我的問題。有人告訴我我的代碼有什麼問題嗎?

回答

2

您正在使用錯誤的算法執行SVM迴歸。 LibSVM用於分類。你想要的那個是SMOreg,這是一個用於迴歸的特定SVM。

下面是一個完整的示例,顯示如何使用SMOreg使用Weka Explorer GUI以及Java API。對於數據,我將使用Weka發行版附帶的cpu.arff數據文件。請注意,我將使用此文件進行培訓和測試,但理想情況下,您將擁有單獨的數據集。

使用Weka的瀏覽器GUI

  1. 打開資源管理器WEKA GUI,單擊Preprocess選項卡上,單擊Open File,然後打開cpu.arff文件應該在你的Weka的分佈。在我的系統上,該文件在weka-3-8-1/data/cpu.arff之下。資源管理器窗口應如下所示:

Weka Explorer - Choosing the file

  • 點擊Classify標籤。它應該被稱爲「預測」,因爲您可以在這裏進行分類和迴歸。在Classifier下,點擊Choose,然後選擇wekaclassifiersfunctionsSMOreg,如下所示。
  • Weka Explorer - Choosing the regression algorithm

  • 現在生成迴歸模型和評估它。在Test Options下選擇Use training set,這樣我們的訓練集也用於測試(正如我在上面提到的,這不是理想的方法)。現在按Start,結果應該如下所示:
  • Weka Explorer - Results from testing

    記下RMSE值(74.5996)的。我們將在Java代碼實現中重新討論這一點。

    使用Java API

    下面是使用Weka的API來複制在Weka的瀏覽器GUI前面顯示的結果一個完整的Java程序。

    import weka.classifiers.functions.SMOreg; 
    import weka.classifiers.Evaluation; 
    import weka.core.Instance; 
    import weka.core.Instances; 
    import weka.core.converters.ConverterUtils.DataSource; 
    
    public class Tester { 
    
        /** 
        * Builds a regression model using SMOreg, the SVM for regression, and 
        * evaluates it with the Evalution framework. 
        */ 
        public void buildAndEvaluate(String trainingArff, String testArff) throws Exception { 
    
         System.out.printf("buildAndEvaluate() called.\n"); 
    
         // Load the training and test instances. 
         Instances trainingInstances = DataSource.read(trainingArff); 
         Instances testInstances = DataSource.read(testArff); 
    
         // Set the true value to be the last field in each instance. 
         trainingInstances.setClassIndex(trainingInstances.numAttributes()-1); 
         testInstances.setClassIndex(testInstances.numAttributes()-1); 
    
         // Build the SMOregression model. 
         SMOreg smo = new SMOreg(); 
         smo.buildClassifier(trainingInstances); 
    
         // Use Weka's evaluation framework. 
         Evaluation eval = new Evaluation(trainingInstances); 
         eval.evaluateModel(smo, testInstances); 
    
         // Print the options that were used in the ML algorithm. 
         String[] options = smo.getOptions(); 
         System.out.printf("Options used:\n"); 
         for (String option : options) { 
          System.out.printf("%s ", option); 
         } 
         System.out.printf("\n\n"); 
    
         // Print the algorithm details. 
         System.out.printf("Algorithm:\n %s\n", smo.toString()); 
    
         // Print the evaluation results. 
         System.out.printf("%s\n", eval.toSummaryString("\nResults\n=====\n", false)); 
        } 
    
        /** 
        * Builds a regression model using SMOreg, the SVM for regression, and 
        * tests each data instance individually to compute RMSE. 
        */ 
        public void buildAndTestEachInstance(String trainingArff, String testArff) throws Exception { 
    
         System.out.printf("buildAndTestEachInstance() called.\n"); 
    
         // Load the training and test instances. 
         Instances trainingInstances = DataSource.read(trainingArff); 
         Instances testInstances = DataSource.read(testArff); 
    
         // Set the true value to be the last field in each instance. 
         trainingInstances.setClassIndex(trainingInstances.numAttributes()-1); 
         testInstances.setClassIndex(testInstances.numAttributes()-1); 
    
         // Build the SMOregression model. 
         SMOreg smo = new SMOreg(); 
         smo.buildClassifier(trainingInstances); 
    
         int numTestInstances = testInstances.numInstances(); 
    
         // This variable accumulates the squared error from each test instance. 
         double sumOfSquaredError = 0.0; 
    
         // Loop over each test instance. 
         for (int i = 0; i < numTestInstances; i++) { 
    
          Instance instance = testInstances.instance(i); 
    
          double trueValue = instance.value(testInstances.classIndex()); 
          double predictedValue = smo.classifyInstance(instance); 
    
          // Uncomment the next line to see every prediction on the test instances. 
          //System.out.printf("true=%10.5f, predicted=%10.5f\n", trueValue, predictedValue); 
    
          double error = trueValue - predictedValue; 
          sumOfSquaredError += (error * error); 
         } 
    
         // Print the RMSE results. 
         double rmse = Math.sqrt(sumOfSquaredError/numTestInstances); 
         System.out.printf("RMSE = %10.5f\n", rmse); 
        } 
    
        public static void main(String argv[]) throws Exception { 
    
         Tester classify = new Tester(); 
         classify.buildAndEvaluate("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff"); 
         classify.buildAndTestEachInstance("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff"); 
        } 
    } 
    

    我已經寫了訓練的SMOreg模型和訓練數據運行預測評估模型兩種功能。

    • buildAndEvaluate()使用的Weka Evaluation框架運行一系列測試,以得到完全相同的 結果作爲資源管理GUI評估模型。值得注意的是,它產生了一個RMSE值。

    • buildAndTestEachInstance()評估由明確 遍歷每個測試實例,進行預測,計算 誤差,並且計算總的RMSE模型。請注意,此RMSE匹配 從buildAndEvaluate()開始的一個,後者與Explorer GUI中的 匹配。

    下面是編譯和運行程序的結果。

    prompt> javac -cp weka.jar Tester.java 
    
    prompt> java -cp .:weka.jar Tester 
    
    buildAndEvaluate() called. 
    Options used: 
    -C 1.0 -N 0 -I weka.classifiers.functions.supportVector.RegSMOImproved -T 0.001 -V -P 1.0E-12 -L 0.001 -W 1 -K weka.classifiers.functions.supportVector.PolyKernel -E 1.0 -C 250007 
    
    Algorithm: 
    SMOreg 
    
    weights (not support vectors): 
    +  0.01 * (normalized) MYCT 
    +  0.4321 * (normalized) MMIN 
    +  0.1847 * (normalized) MMAX 
    +  0.1175 * (normalized) CACH 
    +  0.0973 * (normalized) CHMIN 
    +  0.0235 * (normalized) CHMAX 
    -  0.0168 
    
    
    
    Number of kernel evaluations: 21945 (93.081% cached) 
    
    Results 
    ===== 
    
    Correlation coefficient     0.9044 
    Mean absolute error      31.7392 
    Root mean squared error     74.5996 
    Relative absolute error     33.0908 % 
    Root relative squared error    46.4953 % 
    Total Number of Instances    209  
    
    buildAndTestEachInstance() called. 
    RMSE = 74.59964 
    
    +0

    其實Libsvm有2個SVM類型的迴歸,nu-SVR和epsilon-SVR。通過定義算法的-S參數,我可以決定使用哪種svm類型。在我的代碼中,我使用了epsilon-SVR(-S 3)。但是你的代碼確實幫我找到了我的代碼中的錯誤。 setClassIndex在我的代碼中是錯誤的。我用你的代碼,它的工作。非常感謝您的幫助。 – udi