2016-03-15 58 views
1

我設計了以下功能與任何數值類型的陣列的工作:SparkSQL功能需要類型十進制

def array_sum[T](item:Traversable[T])(implicit n:Numeric[T]) = item.sum 
// Registers a function as a UDF so it can be used in SQL statements. 
sqlContext.udf.register("array_sumD", array_sum(_:Seq[Float])) 

但是想要通過類型的數組浮箱以下錯誤:

// Now we can use our function directly in SparkSQL. 
sqlContext.sql("SELECT array_sumD(array(5.0,1.0,2.0)) as array_sum").show 

錯誤:

cannot resolve 'UDF(array(5.0,1.0,2.0))' due to data type mismatch: argument 1 requires array<double> type, however, 'array(5.0,1.0,2.0)' is of array<decimal(2,1)> type; 

回答

1

Spark-SQL中的十進制值的默認數據類型是,十進制。如果您你的文字查詢到花車,並使用相同的UDF,它的工作原理:

sqlContext.sql(
    """SELECT array_sumD(array(
    | CAST(5.0 AS FLOAT), 
    | CAST(1.0 AS FLOAT), 
    | CAST(2.0 AS FLOAT) 
    |)) as array_sum""".stripMargin).show 

結果,符合市場預期:

+---------+ 
|array_sum| 
+---------+ 
|  8.0| 
+---------+ 

或者,如果你想要使用小數(避免浮點問題),你會仍然必須使用鑄造得到正確的精度,再加上你將不會是abl e使用Scala的不錯Numericsum,因爲小數被讀作java.math.BigDecimal。所以 - 你的代碼是:

def array_sum(item:Traversable[java.math.BigDecimal]) = item.reduce((a, b) => a.add(b)) 

// Registers a function as a UDF so it can be used in SQL statements. 
sqlContext.udf.register("array_sumD", array_sum(_:Seq[java.math.BigDecimal])) 

sqlContext.sql(
    """SELECT array_sumD(array(
    | CAST(5.0 AS DECIMAL(38,18)), 
    | CAST(1.0 AS DECIMAL(38,18)), 
    | CAST(2.0 AS DECIMAL(38,18)) 
    |)) as array_sum""".stripMargin).show