Spark Scala: Count Consecutive Months

I have the following example DataFrame:

Provider  Patient  Date
Smith     John     2016-01-23
Smith     John     2016-02-20
Smith     John     2016-03-21
Smith     John     2016-06-25
Smith     Jill     2016-02-01
Smith     Jill     2016-03-10
James     Jill     2017-04-10
James     Jill     2017-05-11

I would like to programmatically add a column that indicates the number of consecutive months a patient has seen the provider. The new DataFrame would look like this:

Provider  Patient  Date         consecutive_id
Smith     John     2016-01-23   3
Smith     John     2016-02-20   3
Smith     John     2016-03-21   3
Smith     John     2016-06-25   1
Smith     Jill     2016-02-01   2
Smith     Jill     2016-03-10   2
James     Jill     2017-04-10   2
James     Jill     2017-05-11   2

I assume there is a way to do this with Window functions, but I haven't been able to figure it out yet, and I'm looking forward to the community's insight. Thanks.

You could:

  1. Reformat your dates as integers (2016-01 = 1, 2016-02 = 2, 2017-01 = 13, and so on)
  2. Combine all of the dates into an array with a window and collect_list:

    val winSpec = Window.partitionBy("Provider","Patient").orderBy("Date")
    df.withColumn("Dates", collect_list("Date").over(winSpec))

  3. Pass the array to a modified version of @marios' UDF (registered with spark.udf.register) to get the maximum number of consecutive months (a minimal sketch follows this list)
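
A minimal sketch of what steps 1-3 could look like, assuming a spark-shell session where df is the DataFrame from the question. The UDF name maxConsecutive, the year * 12 + month encoding, and the window without an orderBy (so each row sees the full list of visits) are my assumptions, not @marios' exact code; as described, this yields the longest streak per Provider/Patient rather than labelling each streak with its own length as in the expected output:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // Hypothetical UDF: given every visit month of a (Provider, Patient) pair
    // encoded as year * 12 + month, return the longest run of consecutive months.
    val maxConsecutive = udf { (months: Seq[Int]) =>
      months.distinct.sorted
        .foldLeft((0, 0, Int.MinValue)) { case ((best, run, prev), m) =>
          val newRun = if (m == prev + 1) run + 1 else 1
          (math.max(best, newRun), newRun, m)
        }._1
    }

    // no orderBy, so collect_list gathers the whole partition for every row
    val winSpec = Window.partitionBy("Provider", "Patient")

    df.withColumn("month_idx", year($"Date".cast("date")) * 12 + month($"Date".cast("date")))
      .withColumn("consecutive_id", maxConsecutive(collect_list($"month_idx").over(winSpec)))
      .show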

There are at least 3 ways to get the result:

  1. Implement the logic in SQL
  2. Use the Spark API for window functions - .over(windowSpec) (a sketch of this variant follows the SQL output below)
  3. Use .rdd.mapPartitions directly

Introducing Window Functions in Spark SQL

For all of the solutions you can call .toDebugString to see the operations under the hood.

The SQL solution is below. It flags a row with x = 1 whenever its month does not immediately follow the previous visit's month, turns the running sum of x into a group id grp, and then counts the rows in each (provider, patient, grp) group:

val my_df = List(
  ("Smith", "John", "2016-01-23"),
  ("Smith", "John", "2016-02-20"),
  ("Smith", "John", "2016-03-21"),
  ("Smith", "John", "2016-06-25"),
  ("Smith", "Jill", "2016-02-01"),
  ("Smith", "Jill", "2016-03-10"),
  ("James", "Jill", "2017-04-10"),
  ("James", "Jill", "2017-05-11")
  ).toDF(Seq("Provider", "Patient", "Date"): _*)

my_df.createOrReplaceTempView("tbl")

val q = """
select t2.*, count(*) over (partition by provider, patient, grp) consecutive_id
  from (select t1.*, sum(x) over (partition by provider, patient order by yyyymm) grp
          from (select t0.*,
                       case
                          when cast(yyyymm as int) - 
                               cast(lag(yyyymm) over (partition by provider, patient order by yyyymm) as int) = 1
                          then 0
                          else 1
                       end x
                  from (select tbl.*, substr(translate(date, '-', ''), 1, 6) yyyymm from tbl) t0) t1) t2
"""

sql(q).show
sql(q).rdd.toDebugString

Output

scala> sql(q).show
+--------+-------+----------+------+---+---+--------------+
|Provider|Patient|      Date|yyyymm|  x|grp|consecutive_id|
+--------+-------+----------+------+---+---+--------------+
|   Smith|   Jill|2016-02-01|201602|  1|  1|             2|
|   Smith|   Jill|2016-03-10|201603|  0|  1|             2|
|   James|   Jill|2017-04-10|201704|  1|  1|             2|
|   James|   Jill|2017-05-11|201705|  0|  1|             2|
|   Smith|   John|2016-01-23|201601|  1|  1|             3|
|   Smith|   John|2016-02-20|201602|  0|  1|             3|
|   Smith|   John|2016-03-21|201603|  0|  1|             3|
|   Smith|   John|2016-06-25|201606|  1|  2|             1|
+--------+-------+----------+------+---+---+--------------+
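
For reference, the same gaps-and-islands logic can also be sketched directly with the window-function API (option 2 above); this is a minimal translation of the SQL, not a separate solution, and it keeps the yyyymm arithmetic as-is:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val byPatient    = Window.partitionBy($"Provider", $"Patient").orderBy($"yyyymm")
val byPatientGrp = Window.partitionBy($"Provider", $"Patient", $"grp")

my_df
  .withColumn("yyyymm", substring(translate($"Date", "-", ""), 1, 6).cast("int"))
  // x = 1 marks the start of a new streak; lag() is null on the first row,
  // so the when() condition is not satisfied and the row falls through to 1
  .withColumn("x", when($"yyyymm" - lag($"yyyymm", 1).over(byPatient) === 1, 0).otherwise(1))
  // the running sum of x gives each streak its own group id, as in the SQL
  .withColumn("grp", sum($"x").over(byPatient))
  .withColumn("consecutive_id", count(lit(1)).over(byPatientGrp))
  .orderBy($"Provider", $"Patient", $"Date")
  .show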

Update

A hybrid of .mapPartitions + .over(windowSpec):

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType}

val schema = new StructType().add(
             StructField("provider", StringType, true)).add(
             StructField("patient", StringType, true)).add(
             StructField("date", StringType, true)).add(
             StructField("x", IntegerType, true)).add(
             StructField("grp", IntegerType, true))

def f(iter: Iterator[Row]) : Iterator[Row] = {
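  // scanLeft over a partition sorted by date: x is 0 when the current row's yyyymm
  // immediately follows the previous row's, otherwise 1; grp is the running sum of x.
  // The seed Row is dropped at the end.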
  iter.scanLeft(Row("_", "_", "000000", 0, 0))
  {
    case (x1, x2) =>

    val x = 
    if (x2.getString(2).replaceAll("-", "").substring(0, 6).toInt ==
        x1.getString(2).replaceAll("-", "").substring(0, 6).toInt + 1) 
    (0) else (1);

    val grp = x1.getInt(4) + x;

    Row(x2.getString(0), x2.getString(1), x2.getString(2), x, grp);
  }.drop(1)
}

val df_mod = spark.createDataFrame(my_df.repartition($"provider", $"patient")
                                        .sortWithinPartitions($"date")
                                        .rdd.mapPartitions(f, true), schema)

import org.apache.spark.sql.expressions.Window
val windowSpec = Window.partitionBy($"provider", $"patient", $"grp")
df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec)
     ).orderBy($"provider", $"patient", $"date").show

Output

scala> df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec)
     |      ).orderBy($"provider", $"patient", $"date").show
+--------+-------+----------+---+---+--------------+
|provider|patient|      date|  x|grp|consecutive_id|
+--------+-------+----------+---+---+--------------+
|   James|   Jill|2017-04-10|  1|  1|             2|
|   James|   Jill|2017-05-11|  0|  1|             2|
|   Smith|   Jill|2016-02-01|  1|  1|             2|
|   Smith|   Jill|2016-03-10|  0|  1|             2|
|   Smith|   John|2016-01-23|  1|  1|             3|
|   Smith|   John|2016-02-20|  0|  1|             3|
|   Smith|   John|2016-03-21|  0|  1|             3|
|   Smith|   John|2016-06-25|  1|  2|             1|
+--------+-------+----------+---+---+--------------+