2017-03-09 35 views

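How do I define mergeExpressions for a custom DeclarativeAggregate (in the catalyst package)?

I don't understand the general approach for working out the mergeExpressions function of a non-trivial aggregator. For something like org.apache.spark.sql.catalyst.expressions.aggregate.Average, mergeExpressions is simple: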

override lazy val mergeExpressions = Seq(
    /* sum = */ sum.left + sum.right, 
    /* count = */ count.left + count.right 
) 

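The mergeExpressions for the CentralMomentAgg aggregate is more complicated. What I want to do is create a WeightedStddevSamp aggregator modeled on Spark's CentralMomentAgg. I almost have it working, but the weighted standard deviation it produces is still slightly off from what I calculate by hand. I'm having trouble debugging because I don't understand exactly how the logic of the mergeExpressions method is worked out. My code is below. The updateExpressions method is based on this weighted incremental algorithm, so I'm fairly sure that method is correct. I believe my problem is in the mergeExpressions method. Any hints would be appreciated.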

abstract class WeightedCentralMomentAgg(child: Expression, weight: Expression) extends DeclarativeAggregate { 

    override def children: Seq[Expression] = Seq(child, weight) 
    override def nullable: Boolean = true 
    override def dataType: DataType = DoubleType 
    override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) 

    protected val wSum = AttributeReference("wSum", DoubleType, nullable = false)() 
    protected val mean = AttributeReference("mean", DoubleType, nullable = false)() 
    protected val s = AttributeReference("s", DoubleType, nullable = false)() 
    override val aggBufferAttributes = Seq(wSum, mean, s) 
    override val initialValues: Seq[Expression] = Array.fill(3)(Literal(0.0)) 

    // See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm 
    override val updateExpressions: Seq[Expression] = {
        val newWSum = wSum + weight
        val newMean = mean + (weight / newWSum) * (child - mean)
        val newS = s + weight * (child - mean) * (child - newMean)

        // Leave the buffer unchanged when the input value is null
        Seq(
            If(IsNull(child), wSum, newWSum),
            If(IsNull(child), mean, newMean),
            If(IsNull(child), s, newS)
        )
    }

    override val mergeExpressions: Seq[Expression] = {
        val wSum1 = wSum.left
        val wSum2 = wSum.right
        val newWSum = wSum1 + wSum2
        val delta = mean.right - mean.left
        val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta / newWSum)
        val newMean = mean.left + wSum1 / newWSum * delta    // ???
        val newS = s.left + s.right + wSum1 * wSum2 * delta * deltaN  // ???
        Seq(newWSum, newMean, newS)
    }
} 


// Compute the weighted sample standard deviation of a column 
case class WeightedStddevSamp(child: Expression, weight: Expression) 
    extends WeightedCentralMomentAgg(child, weight) { 

    override val evaluateExpression: Expression = {
        If(wSum === Literal(0.0), Literal.create(null, DoubleType),
            If(wSum === Literal(1.0), Literal(Double.NaN),
                Sqrt(s / wSum)))
    }

    override def prettyName: String = "wtd_stddev_samp" 
} 

Answers

0

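In the end I figured out how to write the mergeExpressions function for the weighted standard deviation. I actually had it right, but I was then computing the population variance instead of the sample variance in evaluateExpression. The implementation shown below gives the same result as the one above, but it is easier to understand.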

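override val mergeExpressions: Seq[Expression] = {
    // n is an extra row-count buffer attribute that the class in the question
    // does not declare; it would also have to be added to aggBufferAttributes,
    // initialValues and updateExpressions.
    val newN = n.left + n.right
    val wSum1 = wSum.left
    val wSum2 = wSum.right
    val newWSum = wSum1 + wSum2
    val delta = mean.right - mean.left

    // Guard against dividing by zero when both buffers are empty
    val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta / newWSum)
    // Weighted mean of the two partial means
    val newMean = mean.left + deltaN * wSum2
    // Pooled form: this is exact when s.left and s.right each hold a weighted
    // population variance rather than a sum of squared deviations
    val newS = (((wSum1 * s.left) + (wSum2 * s.right)) / newWSum) + (wSum1 * wSum2 * deltaN * deltaN)

    // Order must match aggBufferAttributes
    Seq(newN, newWSum, newMean, newS)
}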
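To check merge logic like this away from Catalyst, the same arithmetic can be written over plain Doubles. The following is only a sketch, not Spark code: `WeightedMergeSketch`, `Buf`, `update`, `merge` and `fold` are hypothetical names, and `s` is assumed to hold the weighted sum of squared deviations (as in the updateExpressions of the question), not a variance. Folding all rows in one pass should then agree with merging two partial buffers.

```scala
// Plain-Scala sketch (no Spark). Assumption: a buffer holds the total weight,
// the weighted mean, and the weighted sum of squared deviations from that mean.
object WeightedMergeSketch {
  final case class Buf(wSum: Double, mean: Double, s: Double)

  val empty: Buf = Buf(0.0, 0.0, 0.0)

  // Same recurrence as updateExpressions in the question (weights assumed > 0).
  def update(b: Buf, x: Double, w: Double): Buf = {
    val newWSum = b.wSum + w
    val newMean = b.mean + (w / newWSum) * (x - b.mean)
    val newS = b.s + w * (x - b.mean) * (x - newMean)
    Buf(newWSum, newMean, newS)
  }

  // Combine two partial buffers; the guard covers the case of two empty buffers.
  def merge(l: Buf, r: Buf): Buf = {
    val newWSum = l.wSum + r.wSum
    if (newWSum == 0.0) empty
    else {
      val delta = r.mean - l.mean
      val deltaN = delta / newWSum
      Buf(
        newWSum,
        l.mean + deltaN * r.wSum,                    // shift the left mean by the right weight
        l.s + r.s + l.wSum * r.wSum * delta * deltaN // add the between-buffer contribution
      )
    }
  }

  // Fold (value, weight) rows into a single buffer, as one partition would.
  def fold(rows: Seq[(Double, Double)]): Buf =
    rows.foldLeft(empty) { case (b, (x, w)) => update(b, x, w) }
}
```

Comparing `fold(rows)` with `merge(fold(rows.take(k)), fold(rows.drop(k)))` for a few split points `k` is a quick way to see whether a candidate merge formula is consistent with the update formula.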

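Here are some references.

Davies' article gives an outline of the approach, but for many non-trivial aggregators I think the mergeExpressions function can be quite complicated, and working out a correct and efficient solution can involve advanced math. Fortunately, in this case I found someone who had already solved the problem.

That solution matches what I worked out by hand. Note that if you want the sample variance rather than the population variance, evaluateExpression needs a small change (to s / ((n - 1) * wSum / n)).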
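The difference between the two denominators can be seen with plain numbers. This is a minimal sketch, assuming `s` is the weighted sum of squared deviations, `wSum` the total weight and `n` the number of rows; `WeightedVarianceSketch` and its method names are hypothetical and only illustrate the arithmetic, not the Catalyst expressions.

```scala
// Plain-Scala sketch of the two evaluation formulas (no Spark).
object WeightedVarianceSketch {
  // Population form: divide by the total weight.
  def popVariance(wSum: Double, s: Double): Double = s / wSum

  // Sample form from the answer: s / ((n - 1) * wSum / n).
  def sampleVariance(n: Long, wSum: Double, s: Double): Double =
    s / ((n - 1) * wSum / n)
}
```

For the values 1, 2, 3 with unit weights (n = 3, wSum = 3, s = 2), the sample form reduces to the ordinary unweighted sample variance, which is a useful sanity check on the formula.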

2

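For any hash aggregation, there are four steps:

1) Initialize the buffer (wSum, mean, s) in each partition.

2) Update the buffer for a key with every input row (updateExpressions is evaluated once per input).

3) After the shuffle, merge all buffers for the same key using mergeExpressions. wSum.left refers to wSum in the left buffer, and wSum.right refers to wSum in the other buffer.

4) Get the final result from the buffer using evaluateExpression.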
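The four steps can be imitated with ordinary collections. The sketch below is not how Spark executes an aggregate; it is a simplified model using a plain average, where `HashAggStepsSketch`, `Buf` and `run` are hypothetical names and each inner `Seq` stands in for one partition.

```scala
// Plain-Scala model of the four steps for a simple average (no Spark).
object HashAggStepsSketch {
  final case class Buf(sum: Double, count: Long)

  // Step 1: initial buffer for a key within a partition.
  val initial: Buf = Buf(0.0, 0L)

  // Step 2: update the buffer with one input value.
  def update(b: Buf, x: Double): Buf = Buf(b.sum + x, b.count + 1)

  // Step 3: merge two buffers for the same key (the "left" and "right" buffers).
  def merge(l: Buf, r: Buf): Buf = Buf(l.sum + r.sum, l.count + r.count)

  // Step 4: produce the final value from the merged buffer.
  def evaluate(b: Buf): Option[Double] =
    if (b.count == 0) None else Some(b.sum / b.count)

  def run(partitions: Seq[Seq[(String, Double)]]): Map[String, Option[Double]] = {
    // Steps 1 and 2 happen independently inside each partition.
    val partials: Seq[Map[String, Buf]] = partitions.map { part =>
      part.foldLeft(Map.empty[String, Buf]) { case (m, (k, x)) =>
        m.updated(k, update(m.getOrElse(k, initial), x))
      }
    }
    // The "shuffle": bring together the partial buffers that share a key, then merge.
    val merged: Map[String, Buf] = partials.flatten.groupBy(_._1).map {
      case (k, bufs) => k -> bufs.map(_._2).reduce(merge)
    }
    merged.map { case (k, b) => k -> evaluate(b) }
  }
}
```

Because `merge` only ever sees buffers, never raw rows, it has to reconstruct the combined state from the partial states alone, which is why it is harder to write for variance-like aggregates than for sums and counts.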