2017-09-13 75 views
0

我是一個相對初學者Spark的東西。我有一個廣泛的數據幀(1000列),我想將列添加到基於對應的列是否有遺漏值火花柱狀表演

所以

 
+----+   
| A | 
+----+ 
| 1 | 
+----+ 
|null|  
+----+ 
| 3 | 
+----+ 

成爲

 
+----+-------+   
| A | A_MIS | 
+----+-------+ 
| 1 | 0 | 
+----+-------+ 
|null| 1 | 
+----+-------+ 
| 3 | 1 | 
+----+-------+ 

這是一部分定製ml變壓器,但算法應該清晰。

override def transform(dataset: org.apache.spark.sql.Dataset[_]): org.apache.spark.sql.DataFrame = { 
    var ds = dataset 
    dataset.columns.foreach(c => { 
    if (dataset.filter(col(c).isNull).count() > 0) { 
     ds = ds.withColumn(c + "_MIS", when(col(c).isNull, 1).otherwise(0)) 
    } 
    }) 


    ds.toDF() 
} 

循環遍歷列,如果> 0個空值創建一個新列。

傳入的數據集被緩存(使用.cache方法),相關配置設置是默認值。 現在在單檯筆記本電腦上運行,即使使用最少量的行,也可以在1000列上運行40分鐘。 我認爲這個問題是由於碰到一個數據庫造成的,所以我試着用parquet文件來取代相同的結果。看看作業用戶界面,它似乎在做文件掃描以便進行計數。

有沒有一種方法可以改進此算法以獲得更好的性能,或以某種方式調整緩存?增加spark.sql.inMemoryColumnarStorage.batchSize剛剛給我一個OOM錯誤。

回答

0

下面是修復問題的代碼。

override def transform(dataset: Dataset[_]): DataFrame = { 
    var ds = dataset 
    val rowCount = dataset.count() 
    val exprs = dataset.columns.map(count(_)) 
    val colCounts = dataset.agg(exprs.head, exprs.tail: _*).toDF(dataset.columns: _*).first() 
    dataset.columns.foreach(c => { 
    if (colCounts.getAs[Long](c) > 0 && colCounts.getAs[Long](c) < rowCount ) { 
     ds = ds.withColumn(c + "_MIS", when(col(c).isNull, 1).otherwise(0)) 
    } 
    }) 
    ds.toDF() 
}