0
我想用 org.apache.spark.ml.classification.MultilayerPerceptronClassifier 做多類分類，下面是我使用的代碼。我的數據有 262 個特徵，需要把這些特徵列提供給 MultilayerPerceptronClassifier。有人能向我解釋如何為 MultilayerPerceptronClassifier 提供特徵嗎？（標題：Apache Spark MultilayerPerceptronClassifier 設置特徵）
我可以用 setFeaturesCol() 方法指定特徵，但這不可行，因為用它我一次只能添加一個特徵，而我有 262 個特徵。
import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
/**
 * Trains and evaluates a Spark ML multilayer perceptron on headered CSV data.
 *
 * <p>{@link MultilayerPerceptronClassifier} reads its input from ONE vector-valued features
 * column and ONE numeric label column. {@code setFeaturesCol} names that single vector column;
 * it does not add individual predictors. All predictor columns are therefore packed into one
 * vector with a {@link VectorAssembler}. The response column is mapped to a 0-based double class
 * index with a {@link StringIndexer}.
 */
public class NN {
    final static String RESPONSE_VARIABLE = "Activity";

    /** Training CSV used when no path is given on the command line. */
    private static final String DEFAULT_TRAIN_PATH =
            "/home/thamali/Desktop/Project/csv/libsvm/train.csv";

    /** Held-out test CSV used when no second path is given on the command line. */
    private static final String DEFAULT_TEST_PATH =
            "/home/thamali/Desktop/Project/csv/libsvm/test.csv";

    /** Name of the assembled feature-vector column fed to the classifier. */
    private static final String FEATURES_COL = "features";

    /** Name of the indexed (0-based double) label column fed to the classifier. */
    private static final String LABEL_COL = "label";

    /** Name of the prediction column written by the fitted model. */
    private static final String PREDICTION_COL = "prediction";

    /**
     * Trains the perceptron on the training CSV and prints its accuracy on the test CSV.
     *
     * @param args optional {@code [trainCsvPath [testCsvPath]]}; built-in defaults are used for
     *     any path that is omitted
     */
    public static void main(String[] args) {
        String trainPath = args.length > 0 ? args[0] : DEFAULT_TRAIN_PATH;
        String testPath = args.length > 1 ? args[1] : DEFAULT_TEST_PATH;

        SparkConf sparkConf = new SparkConf()
                .setAppName("test-client")
                .setMaster("local[2]")
                .set("spark.driver.allowMultipleContexts", "true");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        try {
            SQLContext sqlContext = new SQLContext(javaSparkContext);
            DataFrame trainDataFrame = loadCsv(sqlContext, trainPath);
            DataFrame testDataFrame = loadCsv(sqlContext, testPath);

            // Every column except the response is treated as a predictor.
            String[] predictors =
                    ArrayUtils.removeElement(trainDataFrame.columns(), RESPONSE_VARIABLE);

            // Pack all predictor columns into the single vector column the classifier reads.
            VectorAssembler assembler = new VectorAssembler()
                    .setInputCols(predictors)
                    .setOutputCol(FEATURES_COL);

            // The classifier needs labels as doubles in [0, numClasses). Index the response,
            // fitting on the training data only so that both sets share one label mapping.
            StringIndexerModel labelIndexer = new StringIndexer()
                    .setInputCol(RESPONSE_VARIABLE)
                    .setOutputCol(LABEL_COL)
                    .fit(trainDataFrame);

            DataFrame train = labelIndexer.transform(assembler.transform(trainDataFrame));
            DataFrame test = labelIndexer.transform(assembler.transform(testDataFrame));

            // The input layer must equal the number of predictors, and the output layer must
            // equal the number of classes. The two hidden layers (50 and 40 units) keep the
            // original configuration.
            int[] layers = new int[] {predictors.length, 50, 40, labelIndexer.labels().length};

            MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
                    .setLayers(layers)
                    .setFeaturesCol(FEATURES_COL)
                    .setLabelCol(LABEL_COL)
                    .setPredictionCol(PREDICTION_COL)
                    .setBlockSize(128)
                    .setSeed(1234L)
                    .setMaxIter(100);
            MultilayerPerceptronClassificationModel model = trainer.fit(train);

            // Score the held-out set.
            DataFrame predictionAndLabels =
                    model.transform(test).select(PREDICTION_COL, LABEL_COL);

            // The Spark 1.x evaluator (this file uses the 1.x DataFrame API) rejects
            // "accuracy". In 1.x, multiclass "precision" is the overall fraction of correct
            // predictions. Switch to "accuracy" when moving to Spark 2.0+.
            MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                    .setLabelCol(LABEL_COL)
                    .setPredictionCol(PREDICTION_COL)
                    .setMetricName("precision");
            System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));
        } finally {
            // Release the local Spark context even when training or evaluation fails.
            javaSparkContext.stop();
        }
    }

    /**
     * Reads a headered CSV file via the spark-csv data source, inferring column types.
     *
     * @param sqlContext the SQL context to read with
     * @param path the CSV file path
     * @return the loaded DataFrame
     */
    private static DataFrame loadCsv(SQLContext sqlContext, String path) {
        return sqlContext.read().format("com.databricks.spark.csv")
                .option("inferSchema", "true")
                .option("header", "true")
                .load(path);
    }
}