2017-08-15

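Spark random forest classifier - getting labels as strings

I'm new to Spark and I want to use its random forest classifier. I built the model from the Iris data in libsvm format.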

My question is: how can I get the labels as strings? (In this case, the labels are the types of iris flower.)

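When the data is converted to libsvm format, each label is replaced by an integer that represents it, but I don't know how to get back to the string label.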
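To illustrate why the names are lost: a libsvm line stores only a numeric label followed by index:value pairs, so the flower names never reach Spark. The plain-Java sketch below splits one such line; 'LibsvmLine' is a hypothetical helper, not part of Spark, and the sample line is built by hand from the first row of output shown in the question. Spark's own libsvm reader additionally converts the 1-based indices to 0-based, which is why the features print as (4,[0,1,2,3],...).

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical helper: splits one libsvm-formatted line into its numeric label and sparse features. */
final class LibsvmLine {
    final double label;                      // the class label is always a number in libsvm
    final Map<Integer, Double> features;     // 1-based feature index -> value, in file order

    private LibsvmLine(double label, Map<Integer, Double> features) {
        this.label = label;
        this.features = features;
    }

    static LibsvmLine parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        double label = Double.parseDouble(tokens[0]);   // first token: the label
        Map<Integer, Double> features = new LinkedHashMap<>();
        for (int i = 1; i < tokens.length; i++) {
            String[] pair = tokens[i].split(":");       // remaining tokens: "index:value"
            features.put(Integer.parseInt(pair[0]), Double.parseDouble(pair[1]));
        }
        return new LibsvmLine(label, features);
    }
}
```

For example, LibsvmLine.parse("1 1:-0.833333 2:0.333333 3:-1 4:-0.916667") yields the label 1.0 and four features; nothing in the line says "Setosa".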

Is this possible with libsvm, or should I use a different format?

Here is my code:

// Imports used by the method below (the enclosing class is not shown)
import java.util.List;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public PipelineModel runRandomForestAlgorithm(String dataPath) {

    System.setProperty("hadoop.home.dir", "C:/hadoop");
    SparkSession spark =
        SparkSession.builder().appName("JavaRandomForestClassifierExample").master("local[*]").getOrCreate();

    /* Load and parse the data file, converting it to a DataFrame. */
    DataFrameReader dataFrameReader = spark.read().format("libsvm");
    Dataset<Row> data = dataFrameReader.load(dataPath);

    /* Index labels, adding metadata to the label column.
       Fit on the whole dataset to include all labels in the index. */
    StringIndexerModel labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data);

    /* Automatically identify categorical features, and index them.
       Set maxCategories so features with > 4 distinct values are treated as continuous. */
    VectorIndexerModel featureIndexer =
        new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(data);

    /* Split the data into training and test sets (10% held out for testing). */
    Dataset<Row>[] splits = data.randomSplit(new double[]{0.9, 0.1});
    Dataset<Row> trainingData = splits[0];
    testData = splits[1]; // not declared here, so presumably a Dataset<Row> field of the enclosing class

    /* Train a RandomForest model. */
    RandomForestClassifier rf =
        new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10);

    /* Convert indexed labels back to original labels. */
    IndexToString labelConverter =
        new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels());

    /* Chain indexers and forest in a Pipeline. */
    Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{labelIndexer, featureIndexer, rf, labelConverter});

    /* Train model. This also runs the indexers. */
    PipelineModel model = pipeline.fit(trainingData);

    /* Make predictions. */
    Dataset<Row> predictions = model.transform(testData);

    /* Collect the test predictions to display them. */
    List<Row> predictionAsRows =
        predictions.select("predictedLabel", "label", "features", "rawPrediction", "probability").collectAsList();

    predictionAsRows.forEach(row -> {
        System.out.println("predictedLabel: " + row.get(0) + " , " + "label: " + row.get(1) + " , " + "features: " + row.get(2) + " , " +
            "predictions: " + row.get(3) + " , " + "probabilities: " + row.get(4));
    });

    return model;
}

Here is the output:

predictedLabel: 1.0 , label: 1.0 , features: (4,[0,1,2,3],[-0.833333,0.333333,-1.0,-0.916667]) , predictions: [10.0,0.0,0.0] , probabilities: [1.0,0.0,0.0]
predictedLabel: 1.0 , label: 1.0 , features: (4,[0,1,2,3],[-0.555556,0.166667,-0.830508,-0.916667]) , predictions: [10.0,0.0,0.0] , probabilities: [1.0,0.0,0.0]
predictedLabel: 2.0 , label: 2.0 , features: (4,[0,1,2,3],[-0.333333,-0.75,0.0169491,-4.03573E-8]) , predictions: [0.0,0.0,10.0] , probabilities: [0.0,0.0,1.0]
predictedLabel: 2.0 , label: 2.0 , features: (4,[0,1,2,3],[-0.166667,-0.416667,-0.0169491,-0.0833333]) , predictions: [0.0,0.0,10.0] , probabilities: [0.0,0.0,1.0]
predictedLabel: 2.0 , label: 2.0 , features: (4,[0,1,2,3],[0.166667,-0.25,0.118644,-4.03573E-8]) , predictions: [0.0,0.0,10.0] , probabilities: [0.0,0.0,1.0]
predictedLabel: 2.0 , label: 2.0 , features: (4,[0,1,2,3],[0.277778,-0.166667,0.152542,0.0833333]) , predictions: [0.0,0.0,10.0] , probabilities: [0.0,0.0,1.0]
predictedLabel: 2.0 , label: 2.0 , features: (4,[0,2,3],[0.5,0.254237,0.0833333]) , predictions: [0.0,0.0,10.0] , probabilities: [0.0,0.0,1.0]
predictedLabel: 3.0 , label: 3.0 , features: (4,[0,1,2,3],[-0.166667,-0.416667,0.38983,0.5]) , predictions: [0.0,9.875,0.125] , probabilities: [0.0,0.9875,0.0125]
predictedLabel: 3.0 , label: 3.0 , features: (4,[0,1,2,3],[0.555555,-0.166667,0.661017,0.666667]) , predictions: [0.0,10.0,0.0] , probabilities: [0.0,1.0,0.0]
predictedLabel: 3.0 , label: 3.0 , features: (4,[0,1,2,3],[0.833333,-0.166667,0.898305,0.666667]) , predictions: [0.0,10.0,0.0] , probabilities: [0.0,1.0,0.0]
predictedLabel: 3.0 , label: 3.0 , features: (4,[0,2,3],[0.222222,0.38983,0.583333]) , predictions: [0.0,10.0,0.0] , probabilities: [0.0,1.0,0.0]
predictedLabel: 3.0 , label: 3.0 , features: (4,[0,2,3],[0.388889,0.661017,0.833333]) , predictions: [0.0,10.0,0.0] , probabilities: [0.0,1.0,0.0]

Answer

With the libsvm format you only get an integer for each class, so you cannot recover a string class label from the data itself.

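Instead, you can use the IndexToString() transformer and pass it your own array of label names through its setLabels() method. For this you should remove the StringIndexerModel (it is not needed, because the classes are already numbers rather than strings) and point the classifier back at the original "label" column with setLabelCol("label"). For example: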

String[] labels = {"unused", "Setosa", "Versicolor", "Virginica"}; // index 0 is a placeholder: the libsvm Iris labels are 1, 2, 3
IndexToString labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labels);
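Why the array needs a placeholder: once the StringIndexer is dropped, the classifier predicts the raw label values, which for the libsvm Iris data are 1.0, 2.0 and 3.0 (when the label column carries no class metadata, Spark infers the number of classes as the largest label plus one). IndexToString uses each prediction as a zero-based position in the labels array and raises an error for a position past its end. The plain-Java sketch below imitates that lookup rule; it is not Spark's implementation, and 'IndexLookup' is a hypothetical name.

```java
/** Hypothetical imitation (not Spark code) of how IndexToString picks a name for a prediction. */
final class IndexLookup {
    private IndexLookup() {}

    /** Uses the prediction as a zero-based position in labels, as IndexToString does. */
    static String toLabel(double prediction, String[] labels) {
        int idx = (int) prediction;
        if (idx < 0 || idx >= labels.length) {
            // Spark also fails on an index outside the labels array
            throw new IllegalArgumentException("Unseen index: " + prediction);
        }
        return labels[idx];
    }
}
```

With {"Setosa", "Versicolor", "Virginica"}, a prediction of 1.0 would come out as "Versicolor" and 3.0 would fail; with the placeholder at position 0, 1.0 maps to "Setosa" and 3.0 to "Virginica".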

Alternatively, you can create a separate Map that maps the integers to string labels. For the Iris dataset it could look like this:

import java.util.HashMap;
import java.util.Map;

Map<Integer, String> labels = new HashMap<>();
labels.put(1, "Setosa");
labels.put(2, "Versicolor");
labels.put(3, "Virginica");

You can then use this Map to get the string labels after all the Spark transformations are done.
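A minimal sketch of the Map approach, applied after Spark has finished: read the numeric prediction from each collected Row (for example with row.getDouble(...) on a 'prediction' column, which would have to be added to the select in the question's code) and look it up. 'PredictionNames' and the "unknown" fallback for ids missing from the Map are hypothetical choices for illustration.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper: turns numeric class predictions into names using a plain Map. */
final class PredictionNames {
    private PredictionNames() {}

    /** The integer-to-name mapping for the libsvm Iris labels 1, 2, 3. */
    static Map<Integer, String> irisLabels() {
        Map<Integer, String> labels = new HashMap<>();
        labels.put(1, "Setosa");
        labels.put(2, "Versicolor");
        labels.put(3, "Virginica");
        return labels;
    }

    /** Maps each prediction (e.g. gathered from collected rows) to its name, or "unknown". */
    static List<String> name(List<Double> predictions, Map<Integer, String> labels) {
        List<String> names = new ArrayList<>(predictions.size());
        for (double p : predictions) {
            names.add(labels.getOrDefault((int) p, "unknown"));
        }
        return names;
    }
}
```

Because this runs on the driver after collectAsList(), it only suits result sets small enough to collect; inside a Spark pipeline, IndexToString is the better fit.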

Hope it helps.


A Map could be very useful, but I don't know how to attach this map to a Spark object. I added these lines and they helped a lot: 'String[] labels = new String[] {"Iris-Setosa", "Iris-versicolor", "Iris-virginica"}; IndexToString stringConverter = new IndexToString().setLabels(labels); /* Convert indexed labels back to original labels. */ IndexToString labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(stringConverter.getLabels());' – Shimrit


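@Shimrit The 'Map' can only be used separately, after all the transformations are done, so 'IndexToString()' is preferable. I'll update the answer to reflect this. If it helped you, please consider accepting the answer by clicking the check mark. :) – Shaido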