Java LightSIDE - How do I classify data using LightSIDE?
I have set up the LightSIDE plugin and it runs correctly, but I do not understand why I cannot save my data to an empty file. This is the simple structure I am working with:
- Activity is the list data that needs to be classified.
- I have 3 categories, and each category has its own types.
- I have defined each category with a specific word list, for example: Food ({Sushi, food, Japanese}, {Cap Jay, food, Chinese}, {Jogging, sport, running}, ...). A sketch of how such labelled rows could look follows this list.
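For illustration, here is a minimal sketch of how labelled training rows for these categories could be assembled; the layout and the example values are my assumption, not taken from the original post, and it reuses the DocumentList calls that appear later in chooseDocumentList(String[]):
// Hypothetical labelled rows: the "Code" column holds the category label,
// the text column holds the activity text.
DocumentList trainDocs = new DocumentList();
trainDocs.setName("TrainingData.csv");
List<String> labels = new ArrayList<String>(Arrays.asList("Food", "Food", "Sport"));
List<String> texts = new ArrayList<String>(Arrays.asList("Sushi", "Cap Jay", "Jogging"));
trainDocs.addAnnotation("Code", labels, false);
trainDocs.addAnnotation("Root Cause Failure Description", texts, false);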
This is how I use LightSIDE:
public void predictSectionType(String[] sections, List<String> activityList) {
    LightSideService currentLightsideHelper = new LightSideService();
    Recipe newRecipe;
    // Initialize SIDEPlugin
    currentLightsideHelper.initSIDEPlugin();
    try {
        // Load Recipe with Extracted Features & Trained Models
        ClassLoader myClassLoader = getClass().getClassLoader();
        newRecipe = ConverterControl.readFromXML(new InputStreamReader(
                myClassLoader.getResourceAsStream("static/lightsideTrainingResult/trainingData.xml")));
        // Predict Result Data
        Recipe recipeToPredict = currentLightsideHelper.loadNewDocumentsFromCSV(sections); // DocumentList & Recipe created
        currentLightsideHelper.predictLabels(recipeToPredict, newRecipe);
    } catch (FileNotFoundException e) {
        e.printStackTrace();
    } catch (IOException e) {
        e.printStackTrace();
    }
}
I use the LightSideService class as a summary class for the LightSIDE functions:
public class LightSideService {
// Extract Features Parameters
final String featureTableName = "1Grams";
final int featureThreshold = 2;
final String featureAnnotation = "Code";
final Type featureType = Type.NOMINAL;
// Build Models Parameters
final String trainingResultName = "Bayes_1Grams";
// Predict Labels Parameters
final String predictionColumnName = featureAnnotation + "_Prediction";
final boolean showMaxScore = false;
final boolean showDists = true;
final boolean overwrite = false;
final boolean useEvaluation = false;
public DocumentListTableModel model = new DocumentListTableModel(null);
public Map<String, Serializable> validationSettings = new TreeMap<String, Serializable>();
public Map<FeaturePlugin, Boolean> featurePlugins = new HashMap<FeaturePlugin, Boolean>();
public Map<LearningPlugin, Boolean> learningPlugins = new HashMap<LearningPlugin, Boolean>();
public Collection<ModelMetricPlugin> modelEvaluationPlugins = new ArrayList<ModelMetricPlugin>();
public Map<WrapperPlugin, Boolean> wrapperPlugins = new HashMap<WrapperPlugin, Boolean>();
// Initialize Data ==================================================
public void initSIDEPlugin() {
SIDEPlugin[] featureExtractors = PluginManager.getSIDEPluginArrayByType("feature_hit_extractor");
boolean selected = true;
for (SIDEPlugin fe : featureExtractors) {
featurePlugins.put((FeaturePlugin) fe, selected);
selected = false;
}
SIDEPlugin[] learners = PluginManager.getSIDEPluginArrayByType("model_builder");
for (SIDEPlugin le : learners) {
learningPlugins.put((LearningPlugin) le, true);
}
SIDEPlugin[] tableEvaluations = PluginManager.getSIDEPluginArrayByType("model_evaluation");
for (SIDEPlugin fe : tableEvaluations) {
modelEvaluationPlugins.add((ModelMetricPlugin) fe);
}
SIDEPlugin[] wrappers = PluginManager.getSIDEPluginArrayByType("learning_wrapper");
for (SIDEPlugin wr : wrappers) {
wrapperPlugins.put((WrapperPlugin) wr, false);
}
}
//Used to Train Models, adjust parameters according to model
public void initValidationSettings(Recipe currentRecipe) {
validationSettings.put("testRecipe", currentRecipe);
validationSettings.put("testSet", currentRecipe.getDocumentList());
validationSettings.put("annotation", "Age");
validationSettings.put("type", "CV");
validationSettings.put("foldMethod", "AUTO");
validationSettings.put("numFolds", 10);
validationSettings.put("source", "RANDOM");
validationSettings.put("test", "true");
}
// Load CSV Doc ==================================================
public Recipe loadNewDocumentsFromCSV(String filePath) {
DocumentList testDocs;
testDocs = chooseDocumentList(filePath);
if (testDocs != null) {
testDocs.guessTextAndAnnotationColumns();
Recipe currentRecipe = Recipe.fetchRecipe();
currentRecipe.setDocumentList(testDocs);
return currentRecipe;
}
return null;
}
public Recipe loadNewDocumentsFromCSV(String[] rootCauseList) {
DocumentList testDocs;
testDocs = chooseDocumentList(rootCauseList);
if (testDocs != null) {
testDocs.guessTextAndAnnotationColumns();
Recipe currentRecipe = Recipe.fetchRecipe();
currentRecipe.setDocumentList(testDocs);
return currentRecipe;
}
return null;
}
protected DocumentList chooseDocumentList(String filePath) {
TreeSet<String> docNames = new TreeSet<String>();
docNames.add(filePath);
try {
DocumentList testDocs;
Charset encoding = Charset.forName("UTF-8");
{
testDocs = ImportController.makeDocumentList(docNames, encoding);
}
return testDocs;
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
protected DocumentList chooseDocumentList(String[] rootCauseList) {
try {
DocumentList testDocs;
testDocs = new DocumentList();
testDocs.setName("TestData.csv");
List<String> codes = new ArrayList<String>();
List<String> roots = new ArrayList<String>();
for (String s : rootCauseList) {
codes.add("");
roots.add((s != null) ? s : "");
}
testDocs.addAnnotation("Code", codes, false);
testDocs.addAnnotation("Root Cause Failure Description", roots, false);
return testDocs;
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
// Save/Load XML ==================================================
public void saveRecipeToXml(Recipe currentRecipe, String filePath) {
File f = new File(filePath);
try {
ConverterControl.writeToXML(f, currentRecipe);
} catch (Exception e) {
e.printStackTrace();
}
}
public Recipe loadRecipeFromXml(String filePath) throws FileNotFoundException, IOException {
Recipe currentRecipe = ConverterControl.loadRecipe(filePath);
return currentRecipe;
}
// Extract Features ==================================================
public Recipe prepareBuildFeatureTable(Recipe currentRecipe) {
// Add Feature Plugins
Collection<FeaturePlugin> plugins = new TreeSet<FeaturePlugin>();
for (FeaturePlugin plugin : featurePlugins.keySet()) {
String pluginString = plugin.toString();
if (pluginString == "Basic Features" || pluginString == "Character N-Grams") {
plugins.add(plugin);
}
}
// Generate Plugin into Recipe
currentRecipe = Recipe.addPluginsToRecipe(currentRecipe, plugins);
// Setup Plugin configurations
OrderedPluginMap currentOrderedPluginMap = currentRecipe.getExtractors();
for (SIDEPlugin plugin : currentOrderedPluginMap.keySet()) {
String pluginString = plugin.toString();
Map<String, String> currentConfigurations = currentOrderedPluginMap.get(plugin);
if (pluginString == "Basic Features") {
for (String s : currentConfigurations.keySet()) {
if (s == "Unigrams" || s == "Bigrams" || s == "Trigrams" ||
s == "Count Occurences" || s == "Normalize N-Gram Counts" ||
s == "Stem N-Grams" || s == "Skip Stopwords in N-Grams") {
currentConfigurations.put(s, "true");
} else {
currentConfigurations.put(s, "false");
}
}
} else if (pluginString == "Character N-Grams") {
for (String s : currentConfigurations.keySet()) {
if (s == "Include Punctuation") {
currentConfigurations.put(s, "true");
} else if (s == "minGram") {
currentConfigurations.put(s, "3");
} else if (s == "maxGram") {
currentConfigurations.put(s, "4");
}
}
currentConfigurations.put("Extract Only Within Words", "true");
}
}
// Build FeatureTable
currentRecipe = buildFeatureTable(currentRecipe, featureTableName, featureThreshold, featureAnnotation, featureType);
return currentRecipe;
}
protected Recipe buildFeatureTable(Recipe currentRecipe, String name, int threshold, String annotation, Type type) {
FeaturePlugin activeExtractor = null;
try {
Collection<FeatureHit> hits = new HashSet<FeatureHit>();
for (SIDEPlugin plug : currentRecipe.getExtractors().keySet()) {
activeExtractor = (FeaturePlugin) plug;
hits.addAll(activeExtractor.extractFeatureHits(currentRecipe.getDocumentList(), currentRecipe.getExtractors().get(plug)));
}
FeatureTable ft = new FeatureTable(currentRecipe.getDocumentList(), hits, threshold, annotation, type);
ft.setName(name);
currentRecipe.setFeatureTable(ft);
} catch (Exception e) {
System.err.println("Feature Extraction Failed");
e.printStackTrace();
}
return currentRecipe;
}
// Build Models ==================================================
public Recipe prepareBuildModel(Recipe currentRecipe) {
try {
// Get Learner Plugins
LearningPlugin learner = null;
for (LearningPlugin plugin : learningPlugins.keySet()) {
/* if (plugin.toString().equals("Naive Bayes")) */
if (plugin.toString().equals("Logistic Regression")) {
learner = plugin;
}
}
if (Boolean.TRUE.toString().equals(validationSettings.get("test"))) {
if (validationSettings.get("type").equals("CV")) {
validationSettings.put("testSet", currentRecipe.getDocumentList());
}
}
Map<String, String> settings = learner.generateConfigurationSettings();
currentRecipe = Recipe.addLearnerToRecipe(currentRecipe, learner, settings);
currentRecipe.setValidationSettings(new TreeMap<String, Serializable>(validationSettings));
for (WrapperPlugin wrap : wrapperPlugins.keySet()) {
if (wrapperPlugins.get(wrap)) {
currentRecipe.addWrapper(wrap, wrap.generateConfigurationSettings());
}
}
buildModel(currentRecipe, validationSettings);
} catch (Exception e) {
e.printStackTrace();
}
return currentRecipe;
}
protected void buildModel(Recipe currentRecipe,
Map<String, Serializable> validationSettings) {
try {
if (currentRecipe != null) {
FeatureTable currentFeatureTable = currentRecipe.getTrainingTable();
TrainingResult results = null;
/*
* if (validationSettings.get("type").equals("SUPPLY")) {
* DocumentList test = (DocumentList)
* validationSettings.get("testSet"); FeatureTable
* extractTestFeatures = prepareTestFeatureTable(currentRecipe,
* validationSettings, test);
* validationSettings.put("testFeatureTable",
* extractTestFeatures);
*
* // if we've already trained the exact same model, don't // do
* it again. Just evaluate. Recipe cached =
* checkForCachedModel(); if (cached != null) { results =
* evaluateUsingCachedModel(currentFeatureTable,
* extractTestFeatures, cached, currentRecipe); } }
*/
if (results == null) {
results = currentRecipe.getLearner().train(currentFeatureTable, currentRecipe.getLearnerSettings(), validationSettings, currentRecipe.getWrappers());
}
if (results != null) {
currentRecipe.setTrainingResult(results);
results.setName(trainingResultName);
currentRecipe.setLearnerSettings(currentRecipe.getLearner().generateConfigurationSettings());
currentRecipe.setValidationSettings(new TreeMap<String, Serializable>(validationSettings));
}
}
} catch (Exception e) {
e.printStackTrace();
}
}
protected static FeatureTable prepareTestFeatureTable(Recipe recipe, Map<String, Serializable> validationSettings, DocumentList test) {
prepareDocuments(recipe, validationSettings, test); // assigns classes, annotations.
Collection<FeatureHit> hits = new TreeSet<FeatureHit>();
OrderedPluginMap extractors = recipe.getExtractors();
for (SIDEPlugin plug : extractors.keySet()) {
Collection<FeatureHit> extractorHits = ((FeaturePlugin) plug).extractFeatureHits(test, extractors.get(plug));
hits.addAll(extractorHits);
}
FeatureTable originalTable = recipe.getTrainingTable();
FeatureTable ft = new FeatureTable(test, hits, 0, originalTable.getAnnotation(), originalTable.getClassValueType());
for (SIDEPlugin plug : recipe.getFilters().keySet()) {
ft = ((RestructurePlugin) plug).filterTestSet(originalTable, ft, recipe.getFilters().get(plug), recipe.getFilteredTable().getThreshold());
}
ft.reconcileFeatures(originalTable.getFeatureSet());
return ft;
}
protected static Map<String, Serializable> prepareDocuments(Recipe currentRecipe, Map<String, Serializable> validationSettings, DocumentList test) throws IllegalStateException {
DocumentList train = currentRecipe.getDocumentList();
try {
test.setCurrentAnnotation(currentRecipe.getTrainingTable().getAnnotation(), currentRecipe.getTrainingTable().getClassValueType());
test.setTextColumns(new HashSet<String>(train.getTextColumns()));
test.setDifferentiateTextColumns(train.getTextColumnsAreDifferentiated());
Collection<String> trainColumns = train.allAnnotations().keySet();
Collection<String> testColumns = test.allAnnotations().keySet();
if (!testColumns.containsAll(trainColumns)) {
ArrayList<String> missing = new ArrayList<String>(trainColumns);
missing.removeAll(testColumns);
throw new java.lang.IllegalStateException("Test set annotations do not match training set.\nMissing columns: " + missing);
}
validationSettings.put("testSet", test);
} catch (Exception e) {
e.printStackTrace();
throw new java.lang.IllegalStateException("Could not prepare test set.\n" + e.getMessage(), e);
}
return validationSettings;
}
//Predict Labels ==================================================
public void predictLabels(Recipe recipeToPredict, Recipe currentRecipe) {
DocumentList newDocs = null;
DocumentList originalDocs;
if (useEvaluation) {
originalDocs = recipeToPredict.getTrainingResult().getEvaluationTable().getDocumentList();
TrainingResult results = currentRecipe.getTrainingResult();
List<String> predictions = (List<String>) results.getPredictions();
newDocs = addLabelsToDocs(predictionColumnName, showDists, overwrite, originalDocs, results, predictions, currentRecipe.getTrainingTable());
} else {
originalDocs = recipeToPredict.getDocumentList();
Predictor predictor = new Predictor(currentRecipe, predictionColumnName);
newDocs = predictor.predict(originalDocs, predictionColumnName, showDists, overwrite);
}
// Predict Labels result
model.setDocumentList(newDocs);
}
protected DocumentList addLabelsToDocs(final String name, final boolean showDists, final boolean overwrite, DocumentList docs, TrainingResult results, List<String> predictions, FeatureTable currentFeatureTable) {
Map<String, List<Double>> distributions = results.getDistributions();
DocumentList newDocs = docs.clone();
newDocs.addAnnotation(name, predictions, overwrite);
if (distributions != null) {
if (showDists) {
for (String label : currentFeatureTable.getLabelArray()) {
List<String> dist = new ArrayList<String>();
for (int i = 0; i < predictions.size(); i++) {
dist.add(String.format("%.3f", distributions.get(label).get(i)));
}
newDocs.addAnnotation(name + "_" + label + "_score", dist, overwrite);
}
}
}
return newDocs;
}
// ==================================================
}
After tracing through all the code, I just noticed that I never created the trained model first. Thanks, I will check that first. –
The problem was that I did not have a trained model, which is why it always failed. –
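For reference, here is a minimal sketch of the missing training step, using only the methods already defined in LightSideService above; the overall flow, the method name trainAndSaveModel, and the file paths are my assumptions rather than part of the original post:
// Assumed training flow: build and save a trained Recipe first, so that
// predictSectionType() has a trainingData.xml to load. Paths are placeholders.
public Recipe trainAndSaveModel(String trainingCsvPath, String modelXmlPath) {
    LightSideService helper = new LightSideService();
    helper.initSIDEPlugin();
    // Load labelled training documents (text column + "Code" label column) from CSV
    Recipe trainingRecipe = helper.loadNewDocumentsFromCSV(trainingCsvPath);
    // Configure the validation settings read by prepareBuildModel()
    helper.initValidationSettings(trainingRecipe);
    // Extract features, then train the model
    trainingRecipe = helper.prepareBuildFeatureTable(trainingRecipe);
    trainingRecipe = helper.prepareBuildModel(trainingRecipe);
    // Persist the trained recipe so it can be loaded for prediction later
    helper.saveRecipeToXml(trainingRecipe, modelXmlPath);
    return trainingRecipe;
}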