
Spark Scala: Count Consecutive Months

I have the following DataFrame, for example:

Provider Patient Date
Smith    John    2016-01-23
Smith    John    2016-02-20
Smith    John    2016-03-21
Smith    John    2016-06-25
Smith    Jill    2016-02-01
Smith    Jill    2016-03-10
James    Jill    2017-04-10
James    Jill    2017-05-11

I would like to programmatically add a column that indicates how many consecutive months the patient has visited the provider. The new DataFrame would look like this:

Provider Patient Date       consecutive_id
Smith    John    2016-01-23 3
Smith    John    2016-02-20 3
Smith    John    2016-03-21 3
Smith    John    2016-06-25 1
Smith    Jill    2016-02-01 2
Smith    Jill    2016-03-10 2
James    Jill    2017-04-10 2
James    Jill    2017-05-11 2

I assume there is a way to accomplish this with a Window function, but I haven't been able to figure it out yet, and I'd appreciate any insight the community can offer. Thanks.

Answers


There are at least three ways to get the result:

  1. In SQL
  2. Using the Spark API, pushing as much of the logic as possible into window functions - .over(windowSpec)
  3. Using .rdd.mapPartitions directly

Introducing Window Functions in Spark SQL

For all of the solutions, you can call .toDebugString to see the operations under the hood.

The SQL solution is below. It derives a yyyymm key from the date, uses lag to flag the first month of each consecutive run (x), takes a running sum of x to assign a group id (grp), and then counts the rows in each (provider, patient, grp) group to get consecutive_id:

// run in the spark-shell; in a standalone application you would also need: import spark.implicits._
val my_df = List(
    ("Smith", "John", "2016-01-23"), 
    ("Smith", "John", "2016-02-20"), 
    ("Smith", "John", "2016-03-21"), 
    ("Smith", "John", "2016-06-25"), 
    ("Smith", "Jill", "2016-02-01"), 
    ("Smith", "Jill", "2016-03-10"), 
    ("James", "Jill", "2017-04-10"), 
    ("James", "Jill", "2017-05-11") 
).toDF(Seq("Provider", "Patient", "Date"): _*) 

my_df.createOrReplaceTempView("tbl") 

val q = """ 
select t2.*, count(*) over (partition by provider, patient, grp) consecutive_id 
    from (select t1.*, sum(x) over (partition by provider, patient order by yyyymm) grp 
      from (select t0.*, 
         case 
          when cast(yyyymm as int) - 
           cast(lag(yyyymm) over (partition by provider, patient order by yyyymm) as int) = 1 
          then 0 
          else 1 
         end x 
        from (select tbl.*, substr(translate(date, '-', ''), 1, 6) yyyymm from tbl) t0) t1) t2 
""" 

sql(q).show 
sql(q).rdd.toDebugString 

Output

scala> sql(q).show 
+--------+-------+----------+------+---+---+--------------+
|Provider|Patient|      Date|yyyymm|  x|grp|consecutive_id|
+--------+-------+----------+------+---+---+--------------+
|   Smith|   Jill|2016-02-01|201602|  1|  1|             2|
|   Smith|   Jill|2016-03-10|201603|  0|  1|             2|
|   James|   Jill|2017-04-10|201704|  1|  1|             2|
|   James|   Jill|2017-05-11|201705|  0|  1|             2|
|   Smith|   John|2016-01-23|201601|  1|  1|             3|
|   Smith|   John|2016-02-20|201602|  0|  1|             3|
|   Smith|   John|2016-03-21|201603|  0|  1|             3|
|   Smith|   John|2016-06-25|201606|  1|  2|             1|
+--------+-------+----------+------+---+---+--------------+
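
For comparison, a minimal sketch of approach 2 (the same gaps-and-islands logic, but expressed through the DataFrame API rather than SQL) could look like the following; it simply mirrors the lag/sum/count structure of the query above and is an assumed translation, not code from a tested answer:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count, lag, lit, substring, sum, translate, when}

// yyyymm key, e.g. 2016-01-23 -> 201601
val withKey = my_df.withColumn("yyyymm",
    substring(translate(col("Date"), "-", ""), 1, 6).cast("int"))

// x = 0 when this month directly follows the previous visit's month, else 1
val byMonth = Window.partitionBy("Provider", "Patient").orderBy("yyyymm")
val flagged = withKey.withColumn("x",
    when(col("yyyymm") - lag(col("yyyymm"), 1).over(byMonth) === 1, 0).otherwise(1))

// running sum of x assigns a group id (grp) to each consecutive run
val grouped = flagged.withColumn("grp", sum(col("x")).over(byMonth))

// the size of each (provider, patient, grp) group is the consecutive_id
val runSpec = Window.partitionBy("Provider", "Patient", "grp")
grouped.withColumn("consecutive_id", count(lit(1)).over(runSpec))
    .orderBy("Provider", "Patient", "Date").show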

Update

A hybrid of .mapPartitions + .over(windowSpec): after repartitioning by provider/patient and sorting each partition by date, a scanLeft inside mapPartitions computes the x and grp columns, and the final count is done with a window function.

import org.apache.spark.sql.Row 
import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType} 

val schema = new StructType().add(
      StructField("provider", StringType, true)).add(
      StructField("patient", StringType, true)).add(
      StructField("date", StringType, true)).add(
      StructField("x", IntegerType, true)).add(
      StructField("grp", IntegerType, true)) 

def f(iter: Iterator[Row]) : Iterator[Row] = { 
    // scan the date-sorted rows of one partition, carrying (x, grp) forward;
    // the seed row is dropped at the end
    iter.scanLeft(Row("_", "_", "000000", 0, 0)) 
    { 
    case (x1, x2) => 

    // x = 0 when this row's yyyymm is exactly one more than the previous row's yyyymm
    val x = 
    if (x2.getString(2).replaceAll("-", "").substring(0, 6).toInt == 
     x1.getString(2).replaceAll("-", "").substring(0, 6).toInt + 1) 
    (0) else (1); 

    // grp increments each time a new run starts
    val grp = x1.getInt(4) + x; 

    Row(x2.getString(0), x2.getString(1), x2.getString(2), x, grp); 
    }.drop(1) 
} 

val df_mod = spark.createDataFrame(my_df.repartition($"provider", $"patient") 
             .sortWithinPartitions($"date") 
             .rdd.mapPartitions(f, true), schema) 

import org.apache.spark.sql.expressions.Window 
val windowSpec = Window.partitionBy($"provider", $"patient", $"grp") 
df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec) 
    ).orderBy($"provider", $"patient", $"date").show 

Output

scala> df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec) 
    |  ).orderBy($"provider", $"patient", $"date").show 
+--------+-------+----------+---+---+--------------+
|provider|patient|      date|  x|grp|consecutive_id|
+--------+-------+----------+---+---+--------------+
|   James|   Jill|2017-04-10|  1|  1|             2|
|   James|   Jill|2017-05-11|  0|  1|             2|
|   Smith|   Jill|2016-02-01|  1|  1|             2|
|   Smith|   Jill|2016-03-10|  0|  1|             2|
|   Smith|   John|2016-01-23|  1|  1|             3|
|   Smith|   John|2016-02-20|  0|  1|             3|
|   Smith|   John|2016-03-21|  0|  1|             3|
|   Smith|   John|2016-06-25|  1|  2|             1|
+--------+-------+----------+---+---+--------------+

This works on the sample data I provided, which is why I'm happy to give it the check mark. I'm just trying to work through a 'java.lang.ArrayIndexOutOfBoundsException: 2' when attempting to show the final 'df_mod' transformation. – bshelt141


Alternatively, you could:

  1. Format the dates as month integers (2016-01 = 1, 2016-02 = 2, 2017-01 = 13, ... etc.)
  2. Collect all of a patient's dates into an array using a window and collect_list:

    val winSpec = Window.partitionBy("Provider","Patient").orderBy("Date")
    df.withColumn("Dates", collect_list("Date").over(winSpec))

  3. Pass the array to a UDF registered with spark.udf.register (a modified version of @marios' solution) to get the number of consecutive months; see the sketch after this list.
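
A rough, untested sketch of those three steps (reusing the my_df defined in the answer above) might look like this; the base-year-2016 month index, the widened collect_list frame, and the consecutiveFor UDF are illustrative assumptions, since @marios' original solution isn't reproduced here:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, collect_list, month, year}

// 1. month index relative to an assumed base year: 2016-01 = 1, 2016-02 = 2, 2017-01 = 13, ...
val withIdx = my_df.withColumn("MonthIdx",
    (year(col("Date").cast("date")) - 2016) * 12 + month(col("Date").cast("date")))

// 2. collect every month index for the provider/patient pair into an array;
//    the frame is widened so each row sees the full list, not just a running prefix
val winSpec = Window.partitionBy("Provider", "Patient")
    .orderBy("MonthIdx")
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
val withMonths = withIdx.withColumn("Months", collect_list("MonthIdx").over(winSpec))

// 3. hypothetical UDF: length of the consecutive run that contains the current month
spark.udf.register("consecutiveFor", (months: Seq[Int], current: Int) => {
    val s = months.toSet
    val down = Iterator.from(current, -1).takeWhile(s.contains).size
    val up = Iterator.from(current + 1).takeWhile(s.contains).size
    down + up
})

withMonths.selectExpr("Provider", "Patient", "Date",
    "consecutiveFor(Months, MonthIdx) as consecutive_id").show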