2016-12-01

Spark 2.0.0: How to use a custom encoded type to aggregate a Dataset?

I am using a tuple encoder combined with a Kryo encoder for LineString:

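import scala.reflect.ClassTag
import org.apache.spark.sql.{Encoder, Encoders}
import com.vividsolutions.jts.geom.LineString

implicit def single[A](implicit c: ClassTag[A]): Encoder[A] = Encoders.kryo[A](c)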
implicit def tuple2[A1, A2](implicit 
          e1: Encoder[A1], 
          e2: Encoder[A2] 
          ): Encoder[(A1,A2)] = Encoders.tuple[A1,A2](e1, e2) 
implicit val lineStringEncoder = Encoders.kryo[LineString] 

val ds = segmentPoints.map(
    sp => { 
    val p1 = new Coordinate(sp.lon_ini, sp.lat_ini) 
    val p2 = new Coordinate(sp.lon_fin, sp.lat_fin) 
    val coords = Array(p1, p2) 

    (sp.id, gf.createLineString(coords)) 
    }) 
    .toDF("id", "segment") 
    .as[(Long, LineString)] 
    .cache 

ds.show 

+----+--------------------+
|  id|             segment|
+----+--------------------+
| 347|[01 00 63 6F 6D 2...|
| 347|[01 00 63 6F 6D 2...|
| 347|[01 00 63 6F 6D 2...|
| 808|[01 00 63 6F 6D 2...|
| 808|[01 00 63 6F 6D 2...|
| 808|[01 00 63 6F 6D 2...|
+----+--------------------+

I can apply any map operation to the segment column and call the underlying LineString methods on data stored as a Dataset[(Long, LineString)]:

ds.map(_._2.getClass.getName).show(false) 

+--------------------------------------+
|value                                 |
+--------------------------------------+
|com.vividsolutions.jts.geom.LineString|
|com.vividsolutions.jts.geom.LineString|
|com.vividsolutions.jts.geom.LineString|
+--------------------------------------+

I want to create some UDAFs to process segments that share the same id. I have tried the following two approaches, without success:

1) Using an Aggregator:

val length = new Aggregator[LineString, Double, Double] with Serializable { 
    def zero: Double = 0      // The initial value. 
    def reduce(b: Double, a: LineString) = b + a.getLength // Add an element to the running total 
    def merge(b1: Double, b2: Double) = b1 + b2 // Merge intermediate values. 
    def finish(b: Double) = b 
    // Following lines are missing on the API doc example but necessary to get 
    // the code compile 
    override def bufferEncoder: Encoder[Double] = Encoders.scalaDouble 
    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble 
}.toColumn 

ds.groupBy("id") 
    .agg(length(col("segment")).as("kms")) 
    .show(false) 

Here I get the following error:

Exception in thread "main" org.apache.spark.sql.AnalysisException: unresolved operator 'Aggregate [id#603L], [id#603L, anon$1(…, None, input[0, double, true] AS value#715, cast(value#715 as double), input[0, double, true] AS value#714, DoubleType, DoubleType)['segment] AS kms#721]; 
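The `['segment]` at the end of the plan hints at the cause: `length(col("segment"))` calls `Column.apply`, which Spark treats as a field extraction on the aggregator column, not as feeding `segment` into the aggregator. A `TypedColumn` from `Aggregator.toColumn` is meant to be passed as-is to a typed operation. A minimal sketch of one alternative, assuming Spark 2.x and JTS 1.x (`com.vividsolutions`) on the classpath: group with `groupByKey`, so both the key function and the aggregator receive deserialized `(Long, LineString)` tuples and the Kryo-encoded column is never referenced by name. `SegmentLength` and `SegmentLengthJob` are hypothetical names.

```scala
import com.vividsolutions.jts.geom.LineString
import org.apache.spark.sql.{Dataset, Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical aggregator: sums LineString.getLength over the (id, segment) rows of a group.
// Its input type is the whole tuple, so it reads the Dataset's own (Kryo) encoding.
object SegmentLength extends Aggregator[(Long, LineString), Double, Double] {
  def zero: Double = 0.0
  def reduce(b: Double, a: (Long, LineString)): Double = b + a._2.getLength
  def merge(b1: Double, b2: Double): Double = b1 + b2
  def finish(b: Double): Double = b
  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object SegmentLengthJob {
  // Typed grouping by id; the key encoder is passed explicitly to avoid
  // ambiguity with a generic implicit Kryo encoder such as `single` above.
  def lengthById(ds: Dataset[(Long, LineString)]): Dataset[(Long, Double)] =
    ds.groupByKey((r: (Long, LineString)) => r._1)(Encoders.scalaLong)
      .agg(SegmentLength.toColumn)
}
```

With `ds` as defined above, `SegmentLengthJob.lengthById(ds).toDF("id", "kms").show(false)` should give one total length per id. The trade-off is that `groupByKey` works on objects, so Catalyst cannot optimize the aggregation the way it would for an untyped `groupBy`.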

2) Using a UserDefinedAggregateFunction:

class Length extends UserDefinedAggregateFunction { 
    val e = Encoders.kryo[LineString] 

    // This is the input fields for your aggregate function. 
    override def inputSchema: StructType = StructType(
    StructField("segment", DataTypes.BinaryType) :: Nil 
) 

    // This is the internal fields you keep for computing your aggregate. 
    override def bufferSchema: StructType = StructType(
     StructField("length", DoubleType) :: Nil 
) 

    // This is the output type of your aggregation function. 
    override def dataType: DataType = DoubleType 

    override def deterministic: Boolean = true 

    // This is the initial value for your buffer schema. 
    override def initialize(buffer: MutableAggregationBuffer): Unit = { 
    buffer(0) = 0.0 
    } 

    // This is how to update your buffer schema given an input. 
    override def update(buffer : MutableAggregationBuffer, input : Row) : Unit = { 
    // val l0 = input.getAs[LineString](0) // Can't cast to LineString (I guess because it is serialized using the given encoder) 
    val b = input.getAs[Array[Byte]](0) // This works fine 
    val lse = e.asInstanceOf[ExpressionEncoder[LineString]] 
    val ls = lse.fromRow(???) // it expects an InternalRow, but input is a Row instance 
    // I also tried casting b.asInstanceOf[InternalRow] without success. 
    buffer(0) = buffer.getAs[Double](0) + ls.getLength 
    } 

    // This is how to merge two objects with the bufferSchema type. 
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 
    buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0) 
    } 

    // This is where you output the final value, given the final value of your bufferSchema. 
    override def evaluate(buffer: Row): Any = { 
    buffer.getDouble(0) 
    } 
} 

val length = new Length 
ds 
    .groupBy("id") 
    .agg(length(col("segment")).as("kms")) 
    .show(false) 
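For the `???` step in approach 2, one possible route is to skip `ExpressionEncoder.fromRow` entirely and deserialize the bytes with Spark's `KryoSerializer`. This is a sketch under an assumption: that `Encoders.kryo` stores each value as the bytes produced by `KryoSerializer.newInstance().serialize` (an internal detail of Spark 2.x, not a documented contract) and that the same `spark.kryo.*` settings apply when reading. `LengthFromKryoBytes` is a hypothetical name.

```scala
import java.nio.ByteBuffer

import com.vividsolutions.jts.geom.LineString
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class LengthFromKryoBytes extends UserDefinedAggregateFunction {
  // Created lazily on each executor: SerializerInstance is not serializable.
  // Assumption: reuse the SparkEnv conf so Kryo registrations match the writer's.
  @transient private lazy val kryo: SerializerInstance = {
    val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
    new KryoSerializer(conf).newInstance()
  }

  override def inputSchema: StructType = StructType(StructField("segment", BinaryType) :: Nil)
  override def bufferSchema: StructType = StructType(StructField("length", DoubleType) :: Nil)
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0.0

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      // Turn the Kryo bytes back into a LineString, then accumulate its length.
      val bytes = input.getAs[Array[Byte]](0)
      val ls = kryo.deserialize[LineString](ByteBuffer.wrap(bytes))
      buffer(0) = buffer.getDouble(0) + ls.getLength
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)

  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
```

Usage would follow the same shape as above: `ds.groupBy("id").agg(new LengthFromKryoBytes()(col("segment")).as("kms")).show(false)`. Because it depends on Spark's internal byte layout, the typed `groupByKey` route is probably the sturdier of the two.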

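What am I doing wrong? I want to use the aggregation API with a custom type rather than falling back to the RDD groupBy API. I searched the Spark documentation but could not find an answer; this part of the API seems to be at an early stage.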

Thanks.

Answer


According to this answer, there is no easy way to pass a custom encoder for a nested type such as (Long, LineString).

One option is to define a case class LineStringWithID that extends LineString with an id: Long attribute, and use the encoders from SQLImplicits.

P.S. Could you break your question into smaller parts, one topic each?
