2016-12-09 21 views
0

假設我有以下數據集:Pyspark ---通過增加每組值的新列

a | b 
1 | 0.4 
1 | 0.8 
1 | 0.5 
2 | 0.4 
2 | 0.1 

我想添加一個名爲「標籤」的新的列,其中值對每個本地確定a中的一組值。的一組中的最高值B一個標記爲1,其他的都標有0

輸出應該是這樣的:

a | b | label 
1 | 0.4 | 0 
1 | 0.8 | 1 
1 | 0.5 | 0 
2 | 0.4 | 1 
2 | 0.1 | 0 

我怎樣才能做到這一點有效地利用PySpark?

回答

2

你可以用窗口功能來做到這一點。首先,您需要一對夫婦的進口:

from pyspark.sql.functions import desc, row_number, when 
from pyspark.sql.window import Window 

和窗口定義:

w = Window().partitionBy("a").orderBy(desc("b")) 

最後你可以使用這些:

df.withColumn("label", when(row_number().over(w) == 1, 1).otherwise(0)) 

例如數據:

df = sc.parallelize([ 
    (1, 0.4), (1, 0.8), (1, 0.5), (2, 0.4), (2, 0.1) 
]).toDF(["a", "b"]) 

結果是:

+---+---+-----+ 
| a| b|label| 
+---+---+-----+ 
| 1|0.8| 1| 
| 1|0.5| 0| 
| 1|0.4| 0| 
| 2|0.4| 1| 
| 2|0.1| 0| 
+---+---+-----+