2017-02-27 40 views
3

在斯卡拉/星火,有一個數據幀:斯卡拉/星火dataframes:找到對應的最大列名

val dfIn = sqlContext.createDataFrame(Seq(
    ("r0", 0, 2, 3), 
    ("r1", 1, 0, 0), 
    ("r2", 0, 2, 2))).toDF("id", "c0", "c1", "c2") 

我想計算一個新列maxCol持有名稱相應列的到最大值(每行)。在這個例子中,輸出應該是:

+---+---+---+---+------+ 
| id| c0| c1| c2|maxCol| 
+---+---+---+---+------+ 
| r0| 0| 2| 3| c2| 
| r1| 1| 0| 0| c0| 
| r2| 0| 2| 2| c1| 
+---+---+---+---+------+ 

其實數據幀有60多列。因此需要一個通用的解決方案。

在Python熊貓(是的,我知道,我應該pyspark比較...)的等效可能是:

dfOut = pd.concat([dfIn, dfIn.idxmax(axis=1).rename('maxCol')], axis=1) 
+1

你一般有多少列? – mrsrinivas

+0

我有大約60列 – ivankeller

+0

最多可以比較多少列? – mrsrinivas

回答

7

有了一個小竅門,你可以使用greatest功能。所需進口:

import org.apache.spark.sql.functions.{col, greatest, lit, struct} 

首先,讓我們創建的structs,其中第一個元素是值,而第二個列名的列表:

val structs = dfIn.columns.tail.map(
    c => struct(col(c).as("v"), lit(c).as("k")) 
) 

結構這樣可以傳遞給greatest如下:

dfIn.withColumn("maxCol", greatest(structs: _*).getItem("k")) 
+---+---+---+---+------+ 
| id| c0| c1| c2|maxCol| 
+---+---+---+---+------+ 
| r0| 0| 2| 3| c2| 
| r1| 1| 0| 0| c0| 
| r2| 0| 2| 2| c2| 
+---+---+---+---+------+ 

請注意,在關係的情況下,它會取出序列中後面出現的元素(按照字典順序(x, "c2") > (x, "c1"))。通過coalescing

import org.apache.spark.sql.functions.when 

val max_col = structs.reduce(
    (c1, c2) => when(c1.getItem("v") >= c2.getItem("v"), c1).otherwise(c2) 
).getItem("k") 

dfIn.withColumn("maxCol", max_col) 
+---+---+---+---+------+ 
| id| c0| c1| c2|maxCol| 
+---+---+---+---+------+ 
| r0| 0| 2| 3| c2| 
| r1| 1| 0| 0| c0| 
| r2| 0| 2| 2| c1| 
+---+---+---+---+------+ 

nullable列的情況下,你必須調整此,例如值-Inf:如果由於某種原因,這是不能接受的,你可以用when明確減少。