根据pyspark中一列过滤条件和window周期计算平均值

Calculate average based on the filtered condition in a column and window period in pyspark

我有一个 pyspark 数据框:

date    | cust  |  amount |  is_delinquent
---------------------------------------
1/1/20  |  A    |  5      |      0
13/1/20 |  A    |  1      |      0
15/1/20 |  A    |  3      |      1
19/1/20 |  A    |  4      |      0
20/1/20 |  A    |  4      |      1
27/1/20 |  A    |  2      |      0
1/2/20  |  A    |  2      |      0
5/2/20  |  A    |  1      |      0
1/1/20  |  B    |  7      |      0
1/1/20  |  B    |  5      |      0

现在我想计算 amount 在 30 天期间 windows 的平均值,并过滤列 IS_DELINQUENT 等于 0。当 IS_DELINQUENT 等于 1 时应跳过并替换为 NaN。

我预期的最终数据帧是:

date    | cust  |  amount |  is_delinquent |   avg_amount
----------------------------------------------------------
1/1/20  |  A    |  5      |      0         |      null
13/1/20 |  A    |  1      |      0         |      5
15/1/20 |  A    |  3      |      1         |      null
19/1/20 |  A    |  4      |      0         |      3
20/1/20 |  A    |  4      |      1         |      null
27/1/20 |  A    |  2      |      0         |      3.333
1/2/20  |  A    |  2      |      0         |      null
5/2/20  |  A    |  1      |      0         |      2
1/1/20  |  B    |  7      |      0         |      null
9/1/20  |  B    |  5      |      0         |      7

没有过滤,我的代码会是这样的:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

days = lambda i: i * 86400
w_pay_30x = Window.partitionBy("cust").orderBy(col("date").cast("timestamp").cast("long")).rangeBetween(-days(30), -days(1))

data.withColumn("avg_amount", F.avg("amount").over(w_pay_30x)

知道如何添加此过滤器吗?

仅当 is_delinquent 等于 0 时,您才可以使用 when 计算和显示平均值。此外,您可能希望在 partition by 子句中包含月份 window.

from pyspark.sql import functions as F, Window

days = lambda i: i * 86400
w_pay_30x = (Window.partitionBy("cust", F.month(F.to_timestamp('date', 'd/M/yy')))
                   .orderBy(F.to_timestamp('date', 'd/M/yy').cast('long'))
                   .rangeBetween(-days(30), -days(1))
            )

data2 = data.withColumn(
    'avg_amount',
    F.when(
        F.col('is_delinquent') == 0, 
        F.avg(
            F.when(
                F.col('is_delinquent') == 0, 
                F.col('amount')
            )
        ).over(w_pay_30x)
    )
).orderBy('cust', F.to_timestamp('date', 'd/M/yy'))

data2.show()
+-------+----+------+-------------+------------------+
|   date|cust|amount|is_delinquent|        avg_amount|
+-------+----+------+-------------+------------------+
| 1/1/20|   A|     5|            0|              null|
|13/1/20|   A|     1|            0|               5.0|
|15/1/20|   A|     3|            1|              null|
|19/1/20|   A|     4|            0|               3.0|
|20/1/20|   A|     4|            1|              null|
|27/1/20|   A|     2|            0|3.3333333333333335|
| 1/2/20|   A|     2|            0|              null|
| 5/2/20|   A|     1|            0|               2.0|
| 1/1/20|   B|     7|            0|              null|
| 9/1/20|   B|     5|            0|               7.0|
+-------+----+------+-------------+------------------+