1

我爲Strings創建了自定義Aggregator[]在多個列上應用自定義Spark聚合器(Spark 2.0)

我想將它應用於DataFrame的所有列,其中所有列都是字符串,但列號是任意的。

我被困在寫正確的表達。我想寫這樣的:

df.agg(df.columns.map(c => myagg(df(c))) : _*) 

這顯然是錯誤的給予各種接口。

我看了一下RelationalGroupedDataset.agg(expr: Column, exprs: Column*)的代碼,但是我不熟悉表達式的操作。

有什麼想法?

+3

請顯示您的聚合器代碼。並解釋你正在嘗試做什麼。 –

+0

@AssafMendelson,實際上我們計劃爲各種數據類型提供各種統計數據的自定義聚合器。我從一個聚合器開始取得最短和最長的字符串:class ShortestLongestAggregator()擴展了Aggregator [String,(String,String),(String,String)]。現在我想爲任意數據框的所有列(因爲它只有字符串列)擁有所有(最短,最長)對。 – mathieu

回答

5

與在單個字段(列)上操作的UserDefinedAggregateFunctions相反,Aggregtors需要完整的 /值。

如果你想和Aggregator它可以在你的代碼段中使用它必須通過列名稱參數化並使用作爲值類型。

import org.apache.spark.sql.expressions.Aggregator 
import org.apache.spark.sql.{Encoder, Encoders, Row} 

case class Max(col: String) 
    extends Aggregator[Row, Int, Int] with Serializable { 

    def zero = Int.MinValue 
    def reduce(acc: Int, x: Row) = 
    Math.max(acc, Option(x.getAs[Int](col)).getOrElse(zero)) 

    def merge(acc1: Int, acc2: Int) = Math.max(acc1, acc2) 
    def finish(acc: Int) = acc 

    def bufferEncoder: Encoder[Int] = Encoders.scalaInt 
    def outputEncoder: Encoder[Int] = Encoders.scalaInt 
} 

用法示例:

val df = Seq((1, None, 3), (4, Some(5), -6)).toDF("x", "y", "z") 

@transient val exprs = df.columns.map(c => Max(c).toColumn.alias(s"max($c)")) 

df.agg(exprs.head, exprs.tail: _*) 
+------+------+------+ 
|max(x)|max(y)|max(z)| 
+------+------+------+ 
|  4|  5|  3| 
+------+------+------+ 

當結合靜態類型DatasetsDataset<Row>按理說Aggregators使更多的意義。

根據您的要求,您也可以使用Seq[_]累加器在單個傳遞中彙總多個列,並在單個merge調用中處理整個(記錄)。