Combine rows in PySpark by continuous ranges

Suppose I have the following DataFrame:

| username | start_range | end_range | user_type |
|----------|-------------|-----------|-----------|
| a        | 1           | 99        | admin     |
| a        | 100         | 200       | admin     |
| b        | 100         | 150       | reader    |
| a        | 300         | 350       | admin     |
| b        | 200         | 300       | reader    |

I want to merge rows over contiguous ranges to get the following DataFrame:

| username | start_range | end_range | user_type |
|----------|-------------|-----------|-----------|
| a        | 1           | 200       | admin     |
| b        | 100         | 150       | reader    |
| a        | 300         | 350       | admin     |
| b        | 200         | 300       | reader    |

I would like to avoid using a UDF.

You can do this in three steps. First, explode the data into one row for every value between start_range and end_range. Then create an id for each run of contiguous values, and finally group by username and id to get the final table.

import pyspark.sql.functions as F
from pyspark.sql import Window

data = [("a", 1, 99, "admin"),
        ("a", 100, 200, "admin"),
        ("b", 100, 150, "reader"),
        ("a", 300, 350, "admin"),
        ("b", 200, 300, "reader")]

df = spark.createDataFrame(data, 
        schema=["username", "start_range", "end_range", "user_type"])

# explode data between ranges
df1 = (df
   .withColumn("seq", F.sequence("start_range", "end_range"))
   .withColumn("tag", F.explode("seq")))

# tag each group with id 
w = Window.partitionBy("username").orderBy("tag")
w1 = (Window
      .partitionBy("username")
      .orderBy("tag")
      .rangeBetween(Window.unboundedPreceding, 0))

df2 = (df1
   .withColumn("tag_lag", F.lag("tag").over(w))
   .fillna(0)
   .withColumn("diff", F.col("tag")-F.col("tag_lag")-1)
   .withColumn("id", F.sum("diff").over(w1)))

# group by id and get start_range and end_range 
df3 = (df2
        .groupBy("username", "user_type", "id")
        .agg(F.min("tag").alias("start_range"), 
             F.max("tag").alias("end_range"))
      ).select("username", "start_range", "end_range", "user_type")
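
A quick way to inspect the result (a usage sketch; the orderBy is only there to make the displayed row order deterministic):

df3.orderBy("username", "start_range").show()

# Expected output, matching the table in the question (roughly):
# +--------+-----------+---------+---------+
# |username|start_range|end_range|user_type|
# +--------+-----------+---------+---------+
# |       a|          1|      200|    admin|
# |       a|        300|      350|    admin|
# |       b|        100|      150|   reader|
# |       b|        200|      300|   reader|
# +--------+-----------+---------+---------+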

The idea is to identify and tag the rows that continue the previous row, and then group them together.

This solution is more efficient than Anna's because it does not use explode, which materializes one row for every value in each range.

import pyspark.sql.functions as F
from pyspark.sql import Window

column_names = ['username', 'start_range', 'end_range', 'user_type']
data = [
    ('a', 1, 99, 'admin'),
    ('a', 100, 200, 'admin'),
    ('b', 100, 150, 'reader'),
    ('a', 300, 350, 'admin'),
    ('b', 200, 300, 'reader'),
]
df = spark.createDataFrame(data=data, schema=column_names)

def merge_by_start_end(df):
    w = Window.orderBy(['username', 'user_type', 'start_range'])
    
    # Used to test whether an entry can be continued
    df = df.withColumn('end_range_lag_p1',
        F.lag(F.col('end_range') + 1, default=-100).over(w))
    
    # Index by cumulative sum, such that entries that can be continued get the same index
    is_entry_terminal = (F.col('end_range_lag_p1') != F.col('start_range')).cast('long')
    df = df.withColumn('merge_index', F.sum(is_entry_terminal).over(w))
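    # with the sample data, sorted by username, user_type and start_range,
    # merge_index comes out as 1, 1, 2, 3, 4, so only the first two rows of
    # user "a" share an index and get merged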
    
    return df.groupBy(['username', 'user_type', 'merge_index']).agg(
        F.min('start_range').alias('start_range'),
        F.max('end_range').alias('end_range'),
    ).drop('merge_index')

merge_by_start_end(df)
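
The call above returns the merged DataFrame but does not display it. A small usage sketch (assuming the sample data matches the question's table; the orderBy is only for readable output):

merge_by_start_end(df).orderBy('username', 'start_range').show()

# Roughly:
# +--------+---------+-----------+---------+
# |username|user_type|start_range|end_range|
# +--------+---------+-----------+---------+
# |       a|    admin|          1|      200|
# |       a|    admin|        300|      350|
# |       b|   reader|        100|      150|
# |       b|   reader|        200|      300|
# +--------+---------+-----------+---------+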