2016-09-15 31 views
1

Java LightSIDE - 如何使用LightSIDE對數據進行分類？我已經設置好LightSIDE插件，並且可以正常運行，但我不知道爲什麼無法將我的數據預測結果保存到文件中？以下是我所做的一個簡單結構。

enter image description here

  1. 活動是需要進行分類的列表數據。
  2. 我有3個類別,每個類別都有各自的類型。
  3. 我已經用特定的單詞列表定義了每個類別。例如:食物({壽司,食物,日本},{Cap Jay,食物,中國},{慢跑,運動,跑步},...)

這就是我如何使用LightSIDE 。

/**
 * Classifies the given sections with a pre-trained LightSIDE model.
 *
 * <p>Loads a trained recipe (feature table + model) from the bundled XML
 * resource, wraps {@code sections} into a new document list, and runs the
 * prediction. The predicted labels end up in the helper's internal
 * {@code model} table; note that nothing here writes them to a file —
 * callers must read them back from the helper if persistence is needed.
 *
 * @param sections     raw text rows to classify
 * @param activityList currently unused; kept for interface compatibility
 */
public void predictSectionType(String[] sections, List<String> activityList) {
     LightSideService currentLightsideHelper = new LightSideService();

     // Initialize SIDEPlugin
     currentLightsideHelper.initSIDEPlugin();

     // Guard against a missing resource: getResourceAsStream returns null
     // (not an exception) when the path is wrong, which previously caused
     // an uncaught NullPointerException inside InputStreamReader.
     java.io.InputStream trainingXml = getClass().getClassLoader()
             .getResourceAsStream("static/lightsideTrainingResult/trainingData.xml");
     if (trainingXml == null) {
         System.err.println("Training data resource not found: static/lightsideTrainingResult/trainingData.xml");
         return;
     }

     // try-with-resources closes the reader/stream (the original leaked it);
     // the charset is pinned to UTF-8 instead of the platform default.
     try (InputStreamReader reader = new InputStreamReader(trainingXml, "UTF-8")) {
      // Load Recipe with Extracted Features & Trained Models
      Recipe newRecipe = ConverterControl.readFromXML(reader);

      // Predict Result Data
      Recipe recipeToPredict = currentLightsideHelper.loadNewDocumentsFromCSV(sections); // DocumentList & Recipe Created
      currentLightsideHelper.predictLabels(recipeToPredict, newRecipe);

     } catch (FileNotFoundException e) {
      e.printStackTrace();
     } catch (IOException e) {
      e.printStackTrace();
     }
    }

我把LightSideService類作爲LightSIDE函數的Summary類。

/**
 * Convenience facade over the LightSIDE machine-learning workbench API:
 * plugin discovery, document loading, feature extraction, model training
 * and label prediction.
 *
 * <p>Not thread-safe: all state lives in mutable public fields.
 */
public class LightSideService {

    // Extract Features parameters
    final String featureTableName = "1Grams";
    final int featureThreshold = 2;
    final String featureAnnotation = "Code";
    final Type featureType = Type.NOMINAL;

    // Build Models parameters
    final String trainingResultName = "Bayes_1Grams";

    // Predict Labels parameters
    final String predictionColumnName = featureAnnotation + "_Prediction";
    final boolean showMaxScore = false;
    final boolean showDists = true;
    final boolean overwrite = false;
    final boolean useEvaluation = false;

    // Receives the document list with predicted labels after predictLabels().
    public DocumentListTableModel model = new DocumentListTableModel(null);

    public Map<String, Serializable> validationSettings = new TreeMap<String, Serializable>();
    public Map<FeaturePlugin, Boolean> featurePlugins = new HashMap<FeaturePlugin, Boolean>();
    public Map<LearningPlugin, Boolean> learningPlugins = new HashMap<LearningPlugin, Boolean>();
    public Collection<ModelMetricPlugin> modelEvaluationPlugins = new ArrayList<ModelMetricPlugin>();
    public Map<WrapperPlugin, Boolean> wrapperPlugins = new HashMap<WrapperPlugin, Boolean>();

    // Initialize Data ==================================================

    /**
     * Discovers every registered SIDE plugin by type and records it in the
     * corresponding map. Only the first feature extractor starts selected;
     * all learners start enabled and all wrappers disabled.
     */
    public void initSIDEPlugin() {
        SIDEPlugin[] featureExtractors = PluginManager.getSIDEPluginArrayByType("feature_hit_extractor");
        boolean selected = true;
        for (SIDEPlugin fe : featureExtractors) {
            featurePlugins.put((FeaturePlugin) fe, selected);
            selected = false; // only the first extractor is pre-selected
        }
        SIDEPlugin[] learners = PluginManager.getSIDEPluginArrayByType("model_builder");
        for (SIDEPlugin le : learners) {
            learningPlugins.put((LearningPlugin) le, true);
        }
        SIDEPlugin[] tableEvaluations = PluginManager.getSIDEPluginArrayByType("model_evaluation");
        for (SIDEPlugin fe : tableEvaluations) {
            modelEvaluationPlugins.add((ModelMetricPlugin) fe);
        }
        SIDEPlugin[] wrappers = PluginManager.getSIDEPluginArrayByType("learning_wrapper");
        for (SIDEPlugin wr : wrappers) {
            wrapperPlugins.put((WrapperPlugin) wr, false);
        }
    }

    /**
     * Populates the cross-validation settings used when training a model.
     * Adjust these values according to the model being built.
     *
     * @param currentRecipe recipe whose document list is used as the test set
     */
    public void initValidationSettings(Recipe currentRecipe) {
        validationSettings.put("testRecipe", currentRecipe);
        validationSettings.put("testSet", currentRecipe.getDocumentList());
        validationSettings.put("annotation", "Age");
        validationSettings.put("type", "CV");
        validationSettings.put("foldMethod", "AUTO");
        validationSettings.put("numFolds", 10);
        validationSettings.put("source", "RANDOM");
        validationSettings.put("test", "true");
    }

    // Load CSV Doc ==================================================

    /**
     * Loads documents from a CSV file and wraps them in a fresh recipe.
     *
     * @param filePath path of the CSV file to import
     * @return a recipe holding the imported documents, or {@code null} on failure
     */
    public Recipe loadNewDocumentsFromCSV(String filePath) {
        DocumentList testDocs = chooseDocumentList(filePath);
        if (testDocs != null) {
            testDocs.guessTextAndAnnotationColumns();
            Recipe currentRecipe = Recipe.fetchRecipe();
            currentRecipe.setDocumentList(testDocs);
            return currentRecipe;
        }
        return null;
    }

    /**
     * Builds a document list directly from in-memory strings and wraps it
     * in a fresh recipe.
     *
     * @param rootCauseList one text entry per document; nulls become empty strings
     * @return a recipe holding the documents, or {@code null} on failure
     */
    public Recipe loadNewDocumentsFromCSV(String[] rootCauseList) {
        DocumentList testDocs = chooseDocumentList(rootCauseList);
        if (testDocs != null) {
            testDocs.guessTextAndAnnotationColumns();
            Recipe currentRecipe = Recipe.fetchRecipe();
            currentRecipe.setDocumentList(testDocs);
            return currentRecipe;
        }
        return null;
    }

    /**
     * Imports a single CSV file (UTF-8) into a LightSIDE document list.
     *
     * @param filePath path of the CSV file
     * @return the imported documents, or {@code null} if the import failed
     */
    protected DocumentList chooseDocumentList(String filePath) {
        TreeSet<String> docNames = new TreeSet<String>();
        docNames.add(filePath);
        try {
            Charset encoding = Charset.forName("UTF-8");
            return ImportController.makeDocumentList(docNames, encoding);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Builds an in-memory document list from raw strings: each entry becomes
     * one row with an empty "Code" label column and the text in the
     * "Root Cause Failure Description" column.
     *
     * @param rootCauseList one text entry per document; nulls become empty strings
     * @return the constructed document list, or {@code null} on failure
     */
    protected DocumentList chooseDocumentList(String[] rootCauseList) {
        try {
            DocumentList testDocs = new DocumentList();
            testDocs.setName("TestData.csv");

            // Parallel columns: empty label + text (typed lists, not raw ArrayList).
            List<String> codes = new ArrayList<String>();
            List<String> roots = new ArrayList<String>();
            for (String s : rootCauseList) {
                codes.add("");
                roots.add((s != null) ? s : "");
            }

            testDocs.addAnnotation("Code", codes, false);
            testDocs.addAnnotation("Root Cause Failure Description", roots, false);
            return testDocs;
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    // Save/Load XML ==================================================

    /**
     * Serializes a recipe to an XML file; failures are logged, not rethrown.
     *
     * @param currentRecipe recipe to save
     * @param filePath      destination file path
     */
    public void saveRecipeToXml(Recipe currentRecipe, String filePath) {
        File f = new File(filePath);
        try {
            ConverterControl.writeToXML(f, currentRecipe);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Deserializes a recipe from an XML file.
     *
     * @param filePath path of the recipe XML
     * @return the loaded recipe
     * @throws FileNotFoundException if the file does not exist
     * @throws IOException           on read failure
     */
    public Recipe loadRecipeFromXml(String filePath) throws FileNotFoundException, IOException {
        return ConverterControl.loadRecipe(filePath);
    }

    // Extract Features ==================================================

    /**
     * Selects and configures the "Basic Features" and "Character N-Grams"
     * extractors, then builds the feature table for the recipe.
     *
     * <p>Bug fix: plugin/configuration names are now compared with
     * {@code equals()} instead of {@code ==} — reference comparison of
     * strings only matched by accident of interning, so the intended
     * configuration was silently skipped.
     *
     * @param currentRecipe recipe holding the documents to extract from
     * @return the recipe with extractors and feature table attached
     */
    public Recipe prepareBuildFeatureTable(Recipe currentRecipe) {
        // Add Feature Plugins (matched by display name)
        Collection<FeaturePlugin> plugins = new TreeSet<FeaturePlugin>();
        for (FeaturePlugin plugin : featurePlugins.keySet()) {
            String pluginString = plugin.toString();
            if ("Basic Features".equals(pluginString) || "Character N-Grams".equals(pluginString)) {
                plugins.add(plugin);
            }
        }

        // Generate Plugin into Recipe
        currentRecipe = Recipe.addPluginsToRecipe(currentRecipe, plugins);

        // Setup Plugin configurations. put() on an existing key is not a
        // structural modification, so updating values while iterating the
        // key set is safe.
        OrderedPluginMap currentOrderedPluginMap = currentRecipe.getExtractors();
        for (SIDEPlugin plugin : currentOrderedPluginMap.keySet()) {
            String pluginString = plugin.toString();
            Map<String, String> currentConfigurations = currentOrderedPluginMap.get(plugin);

            if ("Basic Features".equals(pluginString)) {
                for (String s : currentConfigurations.keySet()) {
                    if (s.equals("Unigrams") || s.equals("Bigrams") || s.equals("Trigrams")
                            || s.equals("Count Occurences") || s.equals("Normalize N-Gram Counts")
                            || s.equals("Stem N-Grams") || s.equals("Skip Stopwords in N-Grams")) {
                        currentConfigurations.put(s, "true");
                    } else {
                        currentConfigurations.put(s, "false");
                    }
                }
            } else if ("Character N-Grams".equals(pluginString)) {
                for (String s : currentConfigurations.keySet()) {
                    if (s.equals("Include Punctuation")) {
                        currentConfigurations.put(s, "true");
                    } else if (s.equals("minGram")) {
                        currentConfigurations.put(s, "3");
                    } else if (s.equals("maxGram")) {
                        currentConfigurations.put(s, "4");
                    }
                }
                currentConfigurations.put("Extract Only Within Words", "true");
            }
        }

        // Build FeatureTable
        currentRecipe = buildFeatureTable(currentRecipe, featureTableName, featureThreshold, featureAnnotation, featureType);

        return currentRecipe;
    }

    /**
     * Runs every extractor attached to the recipe and assembles the hits
     * into a named {@link FeatureTable}.
     *
     * @param currentRecipe recipe with extractors configured
     * @param name          name to give the feature table
     * @param threshold     minimum hit count for a feature to be kept
     * @param annotation    the class-label annotation column
     * @param type          value type of the class label
     * @return the recipe, with the feature table set on success
     */
    protected Recipe buildFeatureTable(Recipe currentRecipe, String name, int threshold, String annotation, Type type) {
        FeaturePlugin activeExtractor = null;
        try {
            Collection<FeatureHit> hits = new HashSet<FeatureHit>();
            for (SIDEPlugin plug : currentRecipe.getExtractors().keySet()) {
                activeExtractor = (FeaturePlugin) plug;
                hits.addAll(activeExtractor.extractFeatureHits(currentRecipe.getDocumentList(), currentRecipe.getExtractors().get(plug)));
            }

            FeatureTable ft = new FeatureTable(currentRecipe.getDocumentList(), hits, threshold, annotation, type);
            ft.setName(name);
            currentRecipe.setFeatureTable(ft);
        } catch (Exception e) {
            System.err.println("Feature Extraction Failed");
            e.printStackTrace();
        }
        return currentRecipe;
    }

    // Build Models ==================================================

    /**
     * Selects the "Logistic Regression" learner, attaches it (plus any
     * enabled wrappers and the validation settings) to the recipe, and
     * trains the model.
     *
     * <p>Bug fixes: the learner name is matched with {@code equals()}
     * instead of {@code ==}, and a missing learner now produces a clear
     * error instead of a NullPointerException.
     *
     * @param currentRecipe recipe with a feature table already built
     * @return the recipe, with the training result attached on success
     */
    public Recipe prepareBuildModel(Recipe currentRecipe) {
        try {
            // Get Learner Plugin by display name
            LearningPlugin learner = null;
            for (LearningPlugin plugin : learningPlugins.keySet()) {
                /* previously "Naive Bayes" */
                if ("Logistic Regression".equals(plugin.toString())) {
                    learner = plugin;
                }
            }
            if (learner == null) {
                throw new IllegalStateException("No 'Logistic Regression' learner plugin is available; did initSIDEPlugin() run?");
            }

            if (Boolean.TRUE.toString().equals(validationSettings.get("test"))) {
                if ("CV".equals(validationSettings.get("type"))) {
                    validationSettings.put("testSet", currentRecipe.getDocumentList());
                }
            }

            Map<String, String> settings = learner.generateConfigurationSettings();

            currentRecipe = Recipe.addLearnerToRecipe(currentRecipe, learner, settings);
            currentRecipe.setValidationSettings(new TreeMap<String, Serializable>(validationSettings));

            for (WrapperPlugin wrap : wrapperPlugins.keySet()) {
                if (wrapperPlugins.get(wrap)) {
                    currentRecipe.addWrapper(wrap, wrap.generateConfigurationSettings());
                }
            }

            buildModel(currentRecipe, validationSettings);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return currentRecipe;
    }

    /**
     * Trains the recipe's learner on its training table and stores the
     * result (named {@link #trainingResultName}) back on the recipe.
     *
     * <p>Bug fix: the original dereferenced {@code currentRecipe} before
     * its null check, making the check dead code; the guard now comes first.
     * A large commented-out "SUPPLY"/cached-model branch was removed — see
     * version history if it needs to be revived.
     *
     * @param currentRecipe      recipe to train; ignored when {@code null}
     * @param validationSettings settings passed through to the learner
     */
    protected void buildModel(Recipe currentRecipe,
            Map<String, Serializable> validationSettings) {
        if (currentRecipe == null) {
            return;
        }
        try {
            FeatureTable currentFeatureTable = currentRecipe.getTrainingTable();
            TrainingResult results = currentRecipe.getLearner().train(currentFeatureTable, currentRecipe.getLearnerSettings(), validationSettings, currentRecipe.getWrappers());

            if (results != null) {
                currentRecipe.setTrainingResult(results);
                results.setName(trainingResultName);

                currentRecipe.setLearnerSettings(currentRecipe.getLearner().generateConfigurationSettings());
                currentRecipe.setValidationSettings(new TreeMap<String, Serializable>(validationSettings));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Extracts features from a supplied test set using the recipe's trained
     * extractors/filters and reconciles them against the training table's
     * feature set so the model can score them.
     *
     * @param recipe             trained recipe
     * @param validationSettings receives the prepared test set
     * @param test               documents to featurize
     * @return the test feature table aligned with the training features
     */
    protected static FeatureTable prepareTestFeatureTable(Recipe recipe, Map<String, Serializable> validationSettings, DocumentList test) {
        prepareDocuments(recipe, validationSettings, test); // assigns classes, annotations.

        Collection<FeatureHit> hits = new TreeSet<FeatureHit>();
        OrderedPluginMap extractors = recipe.getExtractors();
        for (SIDEPlugin plug : extractors.keySet()) {
            Collection<FeatureHit> extractorHits = ((FeaturePlugin) plug).extractFeatureHits(test, extractors.get(plug));
            hits.addAll(extractorHits);
        }
        FeatureTable originalTable = recipe.getTrainingTable();
        FeatureTable ft = new FeatureTable(test, hits, 0, originalTable.getAnnotation(), originalTable.getClassValueType());
        for (SIDEPlugin plug : recipe.getFilters().keySet()) {
            ft = ((RestructurePlugin) plug).filterTestSet(originalTable, ft, recipe.getFilters().get(plug), recipe.getFilteredTable().getThreshold());
        }

        // Align the test table's features with the training feature set.
        ft.reconcileFeatures(originalTable.getFeatureSet());

        return ft;
    }

    /**
     * Copies the training set's annotation/text-column configuration onto
     * the test set and verifies that the test set carries every annotation
     * column the training set has.
     *
     * @param currentRecipe      recipe providing the training configuration
     * @param validationSettings receives the prepared test set under "testSet"
     * @param test               documents to prepare
     * @return the updated validation settings
     * @throws IllegalStateException if annotation columns are missing or
     *                               preparation fails for any other reason
     */
    protected static Map<String, Serializable> prepareDocuments(Recipe currentRecipe, Map<String, Serializable> validationSettings, DocumentList test) throws IllegalStateException {
        DocumentList train = currentRecipe.getDocumentList();

        try {
            test.setCurrentAnnotation(currentRecipe.getTrainingTable().getAnnotation(), currentRecipe.getTrainingTable().getClassValueType());
            test.setTextColumns(new HashSet<String>(train.getTextColumns()));
            test.setDifferentiateTextColumns(train.getTextColumnsAreDifferentiated());

            Collection<String> trainColumns = train.allAnnotations().keySet();
            Collection<String> testColumns = test.allAnnotations().keySet();
            if (!testColumns.containsAll(trainColumns)) {
                ArrayList<String> missing = new ArrayList<String>(trainColumns);
                missing.removeAll(testColumns);
                throw new java.lang.IllegalStateException("Test set annotations do not match training set.\nMissing columns: " + missing);
            }

            validationSettings.put("testSet", test);
        } catch (Exception e) {
            e.printStackTrace();
            throw new java.lang.IllegalStateException("Could not prepare test set.\n" + e.getMessage(), e);
        }
        return validationSettings;
    }

    //Predict Labels ==================================================

    /**
     * Predicts labels for the documents in {@code recipeToPredict} using
     * the trained model in {@code currentRecipe}, and publishes the
     * annotated documents to {@link #model}.
     *
     * <p>When {@link #useEvaluation} is set, existing evaluation results
     * are copied onto the documents instead of running the predictor.
     *
     * @param recipeToPredict recipe holding the documents to label
     * @param currentRecipe   recipe holding the trained model
     */
    public void predictLabels(Recipe recipeToPredict, Recipe currentRecipe) {
        DocumentList newDocs = null;
        DocumentList originalDocs;
        if (useEvaluation) {
            originalDocs = recipeToPredict.getTrainingResult().getEvaluationTable().getDocumentList();

            TrainingResult results = currentRecipe.getTrainingResult();
            @SuppressWarnings("unchecked") // LightSIDE API returns a raw collection of label strings
            List<String> predictions = (List<String>) results.getPredictions();
            newDocs = addLabelsToDocs(predictionColumnName, showDists, overwrite, originalDocs, results, predictions, currentRecipe.getTrainingTable());
        } else {
            originalDocs = recipeToPredict.getDocumentList();

            Predictor predictor = new Predictor(currentRecipe, predictionColumnName);
            newDocs = predictor.predict(originalDocs, predictionColumnName, showDists, overwrite);
        }

        // Publish the labeled documents.
        model.setDocumentList(newDocs);
    }

    /**
     * Clones the document list and attaches the predicted label column,
     * plus one per-label score column (3-decimal distribution values) when
     * {@code showDists} is set.
     *
     * @param name                 name of the prediction column
     * @param showDists            whether to add per-label score columns
     * @param overwrite            whether existing columns may be replaced
     * @param docs                 documents to annotate (cloned, not mutated)
     * @param results              training result providing distributions
     * @param predictions          predicted label per document
     * @param currentFeatureTable  feature table providing the label set
     * @return a new document list carrying the prediction annotations
     */
    protected DocumentList addLabelsToDocs(final String name, final boolean showDists, final boolean overwrite, DocumentList docs, TrainingResult results, List<String> predictions, FeatureTable currentFeatureTable) {
        Map<String, List<Double>> distributions = results.getDistributions();
        DocumentList newDocs = docs.clone();
        newDocs.addAnnotation(name, predictions, overwrite);
        if (distributions != null && showDists) {
            for (String label : currentFeatureTable.getLabelArray()) {
                List<String> dist = new ArrayList<String>();
                for (int i = 0; i < predictions.size(); i++) {
                    dist.add(String.format("%.3f", distributions.get(label).get(i)));
                }
                newDocs.addAnnotation(name + "_" + label + "_score", dist, overwrite);
            }
        }
        return newDocs;
    }

    // ==================================================
}

回答

1

David你好。看起來你在上面重新實現了 edu.cmu.side.recipe 包的許多功能。然而，你的 predictSectionType() 方法似乎並沒有在任何地方實際輸出模型的預測結果。

如果你想要做的是確保使用訓練有素的模型保存新數據的預測,請查看edu.cmu.side.recipe.Predictor類。它需要一個經過訓練的模型路徑作爲輸入,它被scripts/predict.sh便利腳本使用,但如果需要以編程方式調用它,則可以重新調整其主要方法。

我希望這有助於!

+0

追蹤完所有代碼後,我只是注意到我沒有先創建Trained模型。謝謝。我會先檢查它 –

+0

問題是我當時還沒有訓練好的模型，這就是它總是出錯的原因。 –

相關問題