2017-10-14 46 views
0

我想在Spark中同時計算多個列的模式,並使用此計算值來計算DataFrame中的錯誤。我發現如何計算例如一個意思,但我認爲一種模式更復雜。計算多列的模式

這是一個平均的計算:

val multiple_mean = df.na.fill(df.columns.zip(
    df.select(intVars.map(mean(_)): _*).first.toSeq 
).toMap) 

我能計算出蠻力方式的模式:

var list = ArrayBuffer.empty[Float] 

for(column <- df.columns){ 
    list += df.select(column).groupBy(col(column)).count().orderBy(desc("count")).first.toSeq(0).asInstanceOf[Float] 
} 

val multiple_mode = df.na.fill(df.columns.zip(list.toSeq).toMap) 

如果我們考慮性能有什麼方法是最好的?

謝謝你的幫助。

回答

2

您可以使用UserDefinedAggregateFunction。下面的代碼在火花1.6.2中測試

首先創建一個擴展UserDefinedAggregateFunction的類。

import org.apache.spark.sql.Row 
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} 
import org.apache.spark.sql.types._ 

class ModeUDAF extends UserDefinedAggregateFunction{ 

    override def dataType: DataType = StringType 

    override def inputSchema: StructType = new StructType().add("input", StringType) 

    override def deterministic: Boolean = true 

    override def bufferSchema: StructType = new StructType().add("mode", MapType(StringType, LongType)) 

    override def initialize(buffer: MutableAggregationBuffer): Unit = { 
    buffer(0) = Map.empty[Any, Long] 
    } 

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { 
    val buff0 = buffer.getMap[Any, Long](0) 
    val inp = input.get(0) 
    buffer(0) = buff0.updated(inp, buff0.getOrElse(inp, 0L) + 1L) 
    } 

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 
    val mp1 = buffer1.getMap[Any, Long](0) 
    val mp2 = buffer2.getMap[Any, Long](0) 

    buffer1(0) = mp1 ++ mp2.map { case (k, v) => k -> (v + mp1.getOrElse(k, 0L)) } 
    } 

    override def evaluate(buffer: Row): Any = { 
    lazy val st = buffer.getMap[Any, Long](0).toStream 
    val mode = st.foldLeft(st.head){case (e, s) => if (s._2 > e._2) s else e} 
    mode._1 
    } 

} 

後續字符可以按照以下方式在數據框中使用它。

val modeColumnList = List("some", "column", "names") // or df.columns.toList 
val modeAgg = new ModeUDAF() 
val aggCols = modeColumnList.map(c => modeAgg(df(c))) 
val aggregatedModeDF = df.agg(aggCols.head, aggCols.tail: _*) 
aggregatedModeDF.show() 

你也可以在最後的數據框上使用.collect來收集一個scala數據結構的結果。

注意:此解決方案的性能取決於輸入列的基數。

+0

謝謝,我發現它只有在基數低時才合理。我在我生成的數據上嘗試這種方法,其中每個類別只有1,2,3個值,並且此方法非常慢。 –