我不明白確定非平凡聚合器mergeExpressions函數的一般方法。 這樣的事情org.apache.spark.sql.catalyst.expressions.aggregate.Average的mergeExpresssions方法很簡單:如何爲自定義DeclarativeAggregate定義mergeExpressions(在催化劑包中)
override lazy val mergeExpressions = Seq(
/* sum = */ sum.left + sum.right,
/* count = */ count.left + count.right
)
的mergeExpressions爲CentralMomentAgg聚合是更復雜一些。 我想要做的是創建一個WeightedStddevSamp聚合器模仿火花CentralMomentAgg。 我幾乎有它的工作,但它產生的加權標準偏差仍然有點偏離我手工計算。 我在調試時遇到了問題,因爲我不明白如何計算mergeExpressions方法的確切邏輯。 以下是我的代碼。 updateExpressions方法基於這個weighted incremental algorithm,所以我很確定該方法是正確的。我相信我的問題在mergeExpressions方法中。任何提示將不勝感激。
abstract class WeightedCentralMomentAgg(child: Expression, weight: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = Seq(child, weight)
override def nullable: Boolean = true
override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
protected val wSum = AttributeReference("wSum", DoubleType, nullable = false)()
protected val mean = AttributeReference("mean", DoubleType, nullable = false)()
protected val s = AttributeReference("s", DoubleType, nullable = false)()
override val aggBufferAttributes = Seq(wSum, mean, s)
override val initialValues: Seq[Expression] = Array.fill(3)(Literal(0.0))
// See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm
override val updateExpressions: Seq[Expression] = {
val newWSum = wSum + weight
val newMean = mean + (weight/newWSum) * (child - mean)
val newS = s + weight * (child - mean) * (child - newMean)
Seq(
If(IsNull(child), wSum, newWSum),
If(IsNull(child), mean, newMean),
If(IsNull(child), s, newS)
)
}
override val mergeExpressions: Seq[Expression] = {
val wSum1 = wSum.left
val wSum2 = wSum.right
val newWSum = wSum1 + wSum2
val delta = mean.right - mean.left
val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta/newWSum)
val newMean = mean.left + wSum1/newWSum * delta // ???
val newS = s.left + s.right + wSum1 * wSum2 * delta * deltaN // ???
Seq(newWSum, newMean, newS)
}
}
// Compute the weighted sample standard deviation of a column
case class WeightedStddevSamp(child: Expression, weight: Expression)
extends WeightedCentralMomentAgg(child, weight) {
override val evaluateExpression: Expression = {
If(wSum === Literal(0.0), Literal.create(null, DoubleType),
If(wSum === Literal(1.0), Literal(Double.NaN),
Sqrt(s/wSum)))
}
override def prettyName: String = "wtd_stddev_samp"
}