2016-11-28

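Apache Spark MultilayerPerceptronClassifier: setting features

I want to use org.apache.spark.ml.classification.MultilayerPerceptronClassifier for multiclass classification. The code I am using is given below. I have 262 features, and I need to feed all of these feature columns to MultilayerPerceptronClassifier. Can someone explain how to provide the features to MultilayerPerceptronClassifier?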

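I can specify the features with the setFeaturesCol() method, but that is not workable: it takes a single column name, so I could only add one feature at a time, and I have 262 features.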

import org.apache.commons.lang3.ArrayUtils; 
import org.apache.spark.SparkConf; 
import org.apache.spark.api.java.JavaSparkContext; 
import org.apache.spark.sql.Row; 
import org.apache.spark.sql.SQLContext; 
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; 
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; 
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; 
import org.apache.spark.sql.DataFrame; 

public class NN { 

    final static String RESPONSE_VARIABLE = "Activity"; 
    public static void main(String[] args) { 
     // Load training data 
     SparkConf sparkConf = new SparkConf(); 
     sparkConf.setAppName("test-client").setMaster("local[2]"); 
     sparkConf.set("spark.driver.allowMultipleContexts", "true"); 
     JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf); 
     SQLContext sqlContext = new SQLContext(javaSparkContext); 

     // Convert data in csv format to Spark data frame 
     DataFrame trainDataFrame = sqlContext.read().format("com.databricks.spark.csv") 
       .option("inferSchema", "true") 
       .option("header", "true") 
       .load("/home/thamali/Desktop/Project/csv/libsvm/train.csv"); 

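     // NOTE: this reads the same file as the training set, so the score below is
     // training accuracy; point it at a separate held-out file for a real test score
     DataFrame testDataFrame = sqlContext.read().format("com.databricks.spark.csv") 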
       .option("inferSchema", "true") 
       .option("header", "true") 
       .load("/home/thamali/Desktop/Project/csv/libsvm/train.csv"); 

     // every column except the response variable is a predictor
     String[] predictors = trainDataFrame.columns(); 
     predictors = ArrayUtils.removeElement(predictors, RESPONSE_VARIABLE); 


// specify layers for the neural network:
// input layer of size 262 (one neuron per feature), two hidden layers of size 50 and 40,
// and an output layer of size 12 (one neuron per class)
     int[] layers = new int[] {262, 50, 40, 12}; 
// create the trainer and set its parameters 
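     // featuresCol defaults to "features" and labelCol to "label"; this DataFrame has
     // neither yet, so the label column is pointed at the response variable here
     MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier() 
       .setLayers(layers) 
       .setLabelCol(RESPONSE_VARIABLE) 
       .setBlockSize(128) 
       .setSeed(1234L) 
       .setMaxIter(100); 
// train the model 
     MultilayerPerceptronClassificationModel model = trainer.fit(trainDataFrame); 
// compute accuracy on the test set 
     DataFrame result = model.transform(testDataFrame); 
     DataFrame predictionAndLabels = result.select("prediction", RESPONSE_VARIABLE); 
     // "accuracy" is accepted from Spark 2.0 on; Spark 1.x evaluators use "precision" instead
     MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() 
       .setLabelCol(RESPONSE_VARIABLE) 
       .setMetricName("accuracy"); 
     System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels)); 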
    } 

} 

Answer


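We can use Apache Spark's VectorAssembler (org.apache.spark.ml.feature.VectorAssembler) to combine all of the required feature columns into a single vector column. setFeaturesCol() expects exactly one column holding feature vectors, so pass the predictors array to VectorAssembler.setInputCols(), write the result to an output column such as "features", and point the classifier at that column.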
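A minimal sketch of that approach, assuming Spark 1.6 (the DataFrame/SQLContext API used in the question) with the spark-csv package on the classpath. The class name NNWithAssembler, the helper predictorColumns, the column names "features" and "label", and passing the CSV path as args[0] are illustrative assumptions. StringIndexer is included on the assumption that Activity may not already be a 0-based numeric class index, which the classifier requires.

```java
import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class NNWithAssembler {

    static final String RESPONSE_VARIABLE = "Activity";

    // Hypothetical helper: every column except the response variable is a predictor.
    static String[] predictorColumns(String[] columns, String response) {
        return Arrays.stream(columns)
                .filter(c -> !c.equals(response))
                .toArray(String[]::new);
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("mlp-assembler").setMaster("local[2]");
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);

        // Assumption: the CSV path is passed as the first command-line argument.
        DataFrame raw = sqlContext.read().format("com.databricks.spark.csv")
                .option("inferSchema", "true")
                .option("header", "true")
                .load(args[0]);

        String[] predictors = predictorColumns(raw.columns(), RESPONSE_VARIABLE);

        // Pack all predictor columns (262 in the question) into one vector column.
        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(predictors)
                .setOutputCol("features");
        DataFrame assembled = assembler.transform(raw);

        // Map the response values to label indices 0.0 .. k-1 (stored as doubles).
        DataFrame indexed = new StringIndexer()
                .setInputCol(RESPONSE_VARIABLE)
                .setOutputCol("label")
                .fit(assembled)
                .transform(assembled);

        // Input layer = number of predictors; output layer = number of distinct labels.
        int numClasses = (int) indexed.select("label").distinct().count();
        int[] layers = new int[] {predictors.length, 50, 40, numClasses};

        MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
                .setLayers(layers)
                .setFeaturesCol("features")
                .setLabelCol("label")
                .setBlockSize(128)
                .setSeed(1234L)
                .setMaxIter(100);

        MultilayerPerceptronClassificationModel model = trainer.fit(indexed);
        model.transform(indexed).select("prediction", "label").show(5);

        sc.stop();
    }
}
```

Deriving the first entry of layers from predictors.length, rather than hard-coding 262, keeps the input layer in sync with the assembled vector if columns are added or dropped. The same VectorAssembler (and the fitted StringIndexer model) should be applied to the test DataFrame before calling model.transform on it.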