2017-02-21

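Hello, here is my specific requirement: I want to write a UDAF that simply collects all of its input rows. How do I write a simple row-collecting Spark UDAF?

The input is a row with two columns, both of DoubleType.

The intermediate (buffer) schema is, I think, an ArrayList (correct me if I'm wrong).

The returned data type is an ArrayList.

I have written down the "idea" of my UDAF, but I would like someone to help me complete it.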

class CollectorUDAF() extends UserDefinedAggregateFunction { 

    // Input Data Type Schema 
    def inputSchema: StructType = StructType(Array(StructField("value", DoubleType), StructField("y", DoubleType))) 

    // Intermediate Schema 
    def bufferSchema = util.ArrayList[Array(StructField("value", DoubleType), StructField("y", DoubleType)] 

    // Returned Data Type . 
    def dataType: DataType = util.ArrayList[Array(StructField("value", DoubleType), StructField("y", DoubleType)] 

    // Self-explaining 
    def deterministic = true 

    // This function is called whenever key changes 
    def initialize(buffer: MutableAggregationBuffer) = { 

    } 

    // Iterate over each entry of a group 
    def update(buffer: MutableAggregationBuffer, input: Row) = { 


    } 

    // Called after all the entries are exhausted. 
    def evaluate(buffer: Row) = { 

    } 

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 

    } 

}

Answer

If I understand your question correctly, the following should solve it:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, StructType}

class CollectorUDAF() extends UserDefinedAggregateFunction {

    // Input Data Type Schema: two Double columns
    def inputSchema: StructType = new StructType().add("value", DoubleType).add("y", DoubleType)

    // Intermediate Schema: a single column holding an array of Doubles
    def bufferSchema: StructType = new StructType().add("array", ArrayType(DoubleType, containsNull = false))

    // Returned Data Type: a flat array of Doubles (value, y, value, y, ...)
    def dataType: DataType = ArrayType(DoubleType, containsNull = false)

    // The same input always produces the same output
    def deterministic: Boolean = true

    // Called to set up an empty buffer for each group
    def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = Seq.empty[Double]
    }

    // Called for each input row of a group: append both columns to the buffer
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getSeq[Double](0) ++ Seq(input.getDouble(0), input.getDouble(1))
    }

    // Combine two partial buffers, e.g. from different partitions
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0)
    }

    // Called after all the entries are exhausted: return the collected values
    def evaluate(buffer: Row): Any = {
        buffer.getSeq[Double](0)
    }
}
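
For comparison, if the goal is only to gather each group's rows, Spark's built-in `collect_list` combined with `struct` may already be enough, without a custom UDAF. The sketch below is a minimal illustration under stated assumptions: Spark 2.x or later, and a grouping column named `key` plus sample data that are hypothetical and not part of the question. The commented line shows how the `CollectorUDAF` above would be invoked in the same position.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, collect_list, struct}

object CollectRowsExample {

  // Groups by the (hypothetical) "key" column and gathers every (value, y)
  // pair of the group into one array-of-struct column named "rows".
  def collectRows(df: DataFrame): DataFrame =
    df.groupBy("key").agg(collect_list(struct(col("value"), col("y"))).as("rows"))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("collect-rows")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical sample data: two rows for key "a", one row for key "b".
    val df = Seq(("a", 1.0, 2.0), ("a", 3.0, 4.0), ("b", 5.0, 6.0))
      .toDF("key", "value", "y")

    collectRows(df).show(truncate = false)

    // With the CollectorUDAF class compiled alongside, the call would look
    // like this (it yields a flat array of doubles instead of structs):
    //   df.groupBy("key").agg(new CollectorUDAF()(col("value"), col("y")).as("values"))

    spark.stop()
  }
}
```

Note that the order of elements inside the collected array is not guaranteed after a shuffle, for either approach, so sort the result if order matters.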