Pyspark window function to calculate number of transits between stops

I am working with Pyspark and I want to create a function that performs the following operation:

Given data describing train user transactions:

+----+-------------------+-------+---------------+-------------+-------------+
|USER|       DATE        |LINE_ID|      STOP     | TOPOLOGY_ID |TRANSPORT_ID |
+----+-------------------+-------+---------------+-------------+-------------+
|John|2021-01-27 07:27:34|      7| King Cross    |       171235|       03    |
|John|2021-01-27 07:28:00|     40| White Chapell |       123582|       03    |  
|John|2021-01-27 07:35:30|      4| Reaven        |       171565|       03    |  
|Tom |2021-01-27 07:27:23|      7| King Cross    |       171235|       03    |    
|Tom |2021-01-27 07:28:30|     40| White Chapell |       123582|       03    |                   
+----+-------------------+-------+---------------+-------------+-------------+

I want to know how many times each combination of stops (A-B, B-C, etc.) occurs within a 30-minute interval.

So, let's say user "John" went from the "King Cross" stop to "White Chapell" at 7:27, and then from "White Chapell" to "Reaven" at 7:35.
At the same time, "Tom" went from "King Cross" to "White Chapell" at 7:27, and then from "White Chapell" to "Oxford Circus" at 7:32.

The result of the operation should look like this:

+----------------------+-----------------+---------------+-----------+
|          DATE        |   ORIG_STOP     |   DEST_STOP   | NUM_TRANS |
+----------------------+-----------------+---------------+-----------+
|   2021-01-27 07:00:00|  King Cross     | White Chapell |       2   |
|   2021-01-27 07:30:00|  White Chapell  | Reaven        |       1   |              
+----------------------+-----------------+---------------+-----------+

I've tried using window functions, but I can't get what I really want.

You can try running the following.

Using Spark SQL

The first CTE, initial_stop_groups, uses the LEAD function to determine the related ORIGIN and DESTINATION stops and times. The next CTE, stop_groups, uses a CASE expression and date functions to determine the applicable 30-minute interval, and filters out non-trips (i.e. rows with no destination stop). The final projection then uses GROUP BY to aggregate on the interval, origin stop and destination stop, counting the resulting NUM_TRANS that share the same 30-minute interval.

Assuming your data is in input_df:

input_df.createOrReplaceTempView("input_df")

output_df = sparkSession.sql("""
 WITH initial_stop_groups AS (
        SELECT
            DATE as ORIG_DATE,
            LEAD(DATE) OVER (
                PARTITION BY USER,TRANSPORT_ID
                ORDER BY DATE
            ) as STOP_DATE,
            STOP as ORIG_STOP,
            LEAD(STOP) OVER (
                PARTITION BY USER,TRANSPORT_ID
                ORDER BY DATE
            ) as DEST_STOP
        FROM
            input_df
    ),
    stop_groups AS (
        SELECT 
            CAST(CONCAT(
              CAST(ORIG_DATE as DATE),
              ' ',
              hour(ORIG_DATE),
              ':',
              CASE WHEN minute(ORIG_DATE) < 30 THEN '00' ELSE '30' END,
              ':00'
            ) AS TIMESTAMP) as ORIG_TIME,
            CASE WHEN STOP_DATE IS NOT NULL THEN CAST(CONCAT(
              CAST(STOP_DATE as DATE),
              ' ',
              hour(STOP_DATE),
              ':',
              CASE WHEN minute(STOP_DATE) < 30 THEN '00' ELSE '30' END,
              ':00'
            ) AS TIMESTAMP) ELSE NULL END as STOP_TIME,
            ORIG_STOP,
            DEST_STOP
        FROM 
            initial_stop_groups
        WHERE
            DEST_STOP IS NOT NULL
    )
    SELECT
        STOP_TIME as DATE, 
        ORIG_STOP,
        DEST_STOP,
        COUNT(1) as NUM_TRANS
    FROM
        stop_groups
    WHERE
        (unix_timestamp(STOP_TIME) - unix_timestamp(ORIG_TIME)) <=30*60
        
    GROUP BY
        STOP_TIME, ORIG_STOP, DEST_STOP
    
""")

output_df.show()
DATE                     | orig_stop     | dest_stop     | num_trans
2021-01-27T07:00:00.000Z | King Cross    | White Chapell | 2
2021-01-27T07:30:00.000Z | White Chapell | Reaven        | 1


  • CAST((STOP_TIME - ORIG_TIME) as STRING) IN ('0 seconds','30 minutes') was replaced with (unix_timestamp(STOP_TIME) - unix_timestamp(ORIG_TIME)) <= 30*60
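
On a side note, the 30-minute bucket that the CONCAT/CASE expression builds can also be derived arithmetically, by flooring the epoch seconds to an 1800-second boundary. This is only a sketch of the equivalent expression, assuming the Spark session time zone matches the wall-clock values stored in DATE:

# Sketch: floor epoch seconds to a multiple of 1800 (30 minutes) and cast back.
sparkSession.sql("""
    SELECT
        DATE,
        CAST(from_unixtime(floor(unix_timestamp(DATE) / 1800) * 1800) AS TIMESTAMP) AS BUCKET_30_MIN
    FROM input_df
""").show()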

Using the Spark API

Actual code

from pyspark.sql import functions as F
from pyspark.sql import Window

next_stop_window = Window.partitionBy("USER", "TRANSPORT_ID").orderBy("DATE")

output_df = (
    input_df.select(
        F.col("DATE").alias("ORIG_DATE"),
        F.lead("DATE").over(next_stop_window).alias("STOP_DATE"),
        F.col("STOP").alias("ORIG_STOP"),
        F.lead("STOP").over(next_stop_window).alias("DEST_STOP"),
    ).where(
        F.col("DEST_STOP").isNotNull()
    ).select(
        F.concat(
            F.col("ORIG_DATE").cast("DATE"),
            F.lit(' '),
            F.hour("ORIG_DATE"),
            F.lit(':'),
            F.when(
                F.minute("ORIG_DATE") < 30, '00'
            ).otherwise('30'),
            F.lit(':00')
        ).cast("TIMESTAMP").alias("ORIG_TIME"),
        F.concat(
            F.col("STOP_DATE").cast("DATE"),
            F.lit(' '),
            F.hour("STOP_DATE"),
            F.lit(':'),
            F.when(
                F.minute("STOP_DATE") < 30, '00'
            ).otherwise('30'),
            F.lit(':00')
        ).cast("TIMESTAMP").alias("STOP_TIME"),
        F.col("ORIG_STOP"),
        F.col("DEST_STOP")
    ).where(
        (F.unix_timestamp("STOP_TIME") - F.unix_timestamp("ORIG_TIME")) <= 30*60
        # (F.col("STOP_TIME")-F.col("ORIG_TIME")).cast("STRING").isin(['0 seconds','30 minutes'])
    ).groupBy(
        F.col("STOP_TIME"),
        F.col("ORIG_STOP"),
        F.col("DEST_STOP"),
    ).count().select(
        F.col("STOP_TIME").alias("DATE"),
        F.col("ORIG_STOP"),
        F.col("DEST_STOP"),
        F.col("count").alias("NUM_TRANS"),
    )
    
)
output_df.show()

DATE                     | orig_stop     | dest_stop     | num_trans
2021-01-27T07:00:00.000Z | King Cross    | White Chapell | 2
2021-01-27T07:30:00.000Z | White Chapell | Reaven        | 1

Result schema

output_df.printSchema()
root
 |-- DATE: timestamp (nullable = true)
 |-- ORIG_STOP: string (nullable = true)
 |-- DEST_STOP: string (nullable = true)
 |-- NUM_TRANS: long (nullable = false)
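
If you are on a Spark version that has it (2.0+), the built-in window function can also do the 30-minute bucketing instead of the string concatenation. Here is a rough sketch of the same pipeline using it; note that it filters on the raw timestamps rather than the bucketed ones, which is a slightly stricter reading of the 30-minute rule:

from pyspark.sql import functions as F
from pyspark.sql import Window

# Same window as above: the next event of the same user on the same transport.
next_stop_window = Window.partitionBy("USER", "TRANSPORT_ID").orderBy("DATE")

alt_df = (
    input_df.select(
        F.col("DATE").alias("ORIG_DATE"),
        F.lead("DATE").over(next_stop_window).alias("STOP_DATE"),
        F.col("STOP").alias("ORIG_STOP"),
        F.lead("STOP").over(next_stop_window).alias("DEST_STOP"),
    ).where(
        F.col("DEST_STOP").isNotNull()
    ).where(
        # filter on the raw timestamps here (not the bucketed ones)
        (F.unix_timestamp("STOP_DATE") - F.unix_timestamp("ORIG_DATE")) <= 30*60
    ).groupBy(
        # F.window buckets a timestamp into fixed 30-minute windows (struct<start, end>)
        F.window("STOP_DATE", "30 minutes"),
        F.col("ORIG_STOP"),
        F.col("DEST_STOP"),
    ).count().select(
        F.col("window.start").alias("DATE"),
        F.col("ORIG_STOP"),
        F.col("DEST_STOP"),
        F.col("count").alias("NUM_TRANS"),
    )
)
alt_df.show()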

Setup code for reproducibility

data="""+----+-------------------+-------+---------------+-------------+-------------+
|USER|       DATE        |LINE_ID|      STOP     | TOPOLOGY_ID |TRANSPORT_ID |
+----+-------------------+-------+---------------+-------------+-------------+
|John|2021-01-27 07:27:34|      7| King Cross    |       171235|       03    |
|John|2021-01-27 07:28:00|     40| White Chapell |       123582|       03    |  
|John|2021-01-27 07:35:30|      4| Reaven        |       171565|       03    |  
|Tom |2021-01-27 07:27:23|      7| King Cross    |       171235|       03    |    
|Tom |2021-01-27 07:28:30|     40| White Chapell |       123582|       03    |                   
+----+-------------------+-------+---------------+-------------+-------------+
"""

rows = [ [ pc.strip() for pc in line.strip().split("|")[1:-1]] for line in data.strip().split("\n")[3:-1]]
headers = [pc.strip() for pc in data.strip().split("\n")[1].split("|")[1:-1]]

from pyspark.sql import functions as F
input_df = sparkSession.createDataFrame(rows,schema=headers)
input_df = input_df.withColumn("DATE",F.col("DATE").cast("TIMESTAMP"))
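
Optionally, sanity-check the reproduced DataFrame before running either version:

# Quick check of the reproduced input data
input_df.printSchema()
input_df.show(truncate=False)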


Let me know if this works for you.