
I want to use Spark MLlib's org.apache.spark.mllib.tree.DecisionTree, as in the code below, but compilation fails. How can I train a decision tree on a dataset loaded from a CSV file?

import org.apache.spark.ml.Pipeline 
import org.apache.spark.ml.classification.DecisionTreeClassifier 
import org.apache.spark.ml.classification.DecisionTreeClassificationModel 
import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} 
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator 
import org.apache.spark.mllib.tree.DecisionTree 
import org.apache.spark.mllib.tree.model.DecisionTreeModel 
import org.apache.spark.mllib.util.MLUtils 
import org.apache.spark.sql.SparkSession 

val sqlContext = new org.apache.spark.sql.SQLContext(sc) 
val data = sqlContext.read.format("csv").load("C:/spark/spark-2.1.0-bin-hadoop2.7/data/mllib/airlines.txt") 
val df = sqlContext.read.csv("C:/spark/spark-2.1.0-bin-hadoop2.7/data/mllib/airlines.txt") 
val dataframe = sqlContext.createDataFrame(df).toDF("label"); 
val splits = data.randomSplit(Array(0.7, 0.3)) 

val (trainingData, testData) = (splits(0), splits(1)) 

val numClasses = 2 
val categoricalFeaturesInfo = Map[Int, Int]() 
val impurity = "gini" 
val maxDepth = 5 
val maxBins = 32 
val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,impurity, maxDepth, maxBins) 

Compilation fails with the following error message:

<console>:44: error: overloaded method value trainClassifier with alternatives: (input: org.apache.spark.api.java.JavaRDD[org.apache.spark.mllib.regression.LabeledPoint],numClasses: Int,categoricalFeaturesInfo: java.util.Map[Integer,Integer],impurity: String,maxDepth: Int,maxBins: Int)org.apache.spark.mllib.tree.model.DecisionTreeModel
(input: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint],numClasses: Int,categoricalFeaturesInfo: scala.collection.immutable.Map[Int,Int],impurity: String,maxDepth: Int,maxBins: Int)org.apache.spark.mllib.tree.model.DecisionTreeModel cannot be applied to (org.apache.spark.sql.Dataset[org.apache.spark.sql.Row], Int, scala.collection.immutable.Map[Int,Int], String, Int, Int) val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,impurity, maxDepth, maxBins)


I get the error "overloaded method value trainClassifier with alternatives" when I run the code above. It would be great if you could solve this problem.

Answer


You are passing a Dataset from the new Spark SQL API to the old RDD-based DecisionTree, hence the compile error:

cannot be applied to (org.apache.spark.sql.Dataset[org.apache.spark.sql.Row], Int, scala.collection.immutable.Map[Int,Int], String, Int, Int) val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,impurity, maxDepth, maxBins)

Note that the first input argument you pass is an org.apache.spark.sql.Dataset[org.apache.spark.sql.Row], but DecisionTree.trainClassifier requires an RDD[org.apache.spark.mllib.regression.LabeledPoint] (or its Java counterpart, org.apache.spark.api.java.JavaRDD[LabeledPoint]).
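If you do want to keep the old RDD-based API, you can map each Row of the Dataset to the LabeledPoint it expects. This is a minimal sketch; the inlined sample rows and the master("local[*]") setting are illustrative stand-ins for your CSV file:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("dt-rdd").getOrCreate()

// In your case this would come from spark.read.csv(...); inlined here for illustration.
val df = spark.createDataFrame(Seq(
  (0.0, 1.0, 10.0), (1.0, 3.0, 30.0),
  (0.0, 1.5, 12.0), (1.0, 2.8, 28.0)
)).toDF("label", "f1", "f2")

// Map each Row to the LabeledPoint that trainClassifier expects:
// first column as the label, remaining columns as a dense feature vector.
val labeled = df.rdd.map { row =>
  LabeledPoint(row.getDouble(0), Vectors.dense(row.getDouble(1), row.getDouble(2)))
}

val model = DecisionTree.trainClassifier(
  labeled, numClasses = 2, categoricalFeaturesInfo = Map[Int, Int](),
  impurity = "gini", maxDepth = 5, maxBins = 32)
```

With a real CSV you would first cast the string columns to Double (e.g. with option("inferSchema", "true") when reading), since csv loads everything as strings by default.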

Quoting Announcement: DataFrame-based API is primary API:

As of Spark 2.0, the RDD-based APIs in the spark.mllib package have entered maintenance mode. The primary Machine Learning API for Spark is now the DataFrame-based API in the spark.ml package.

Per the Decision trees documentation, please change your code accordingly:

The spark.ml implementation supports decision trees for binary and multiclass classification and for regression, using both continuous and categorical features. The implementation partitions data by rows, allowing distributed training with millions or even billions of instances.
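Following that recommendation, here is a minimal DataFrame-based sketch using spark.ml's DecisionTreeClassifier. The inlined sample data, the column names, and the master("local[*]") setting are assumptions standing in for your CSV:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("dt-ml").getOrCreate()

// In your case: spark.read.option("inferSchema", "true").csv(path).toDF("label", "f1", "f2")
val df = spark.createDataFrame(Seq(
  (0.0, 1.0, 10.0), (1.0, 3.0, 30.0),
  (0.0, 1.5, 12.0), (1.0, 2.8, 28.0)
)).toDF("label", "f1", "f2")

// spark.ml works on a single vector column, so assemble the feature columns first.
val assembler = new VectorAssembler()
  .setInputCols(Array("f1", "f2"))
  .setOutputCol("features")

val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")

// The Pipeline chains assembly and training into one fit/transform flow.
val pipeline = new Pipeline().setStages(Array(assembler, dt))
val model = pipeline.fit(df)          // fit on the full sample so the tiny data always trains
val predictions = model.transform(df) // adds "prediction" (and probability) columns
predictions.select("label", "prediction").show()
```

On your real data you would fit on a randomSplit training portion and evaluate the held-out portion with MulticlassClassificationEvaluator, exactly as in the Decision trees example in the docs.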