Pyspark Dataframe - 没有 Numpy 或其他库的中位数

Pyspark Dataframe - Median Without Numpy Or Other Libraries

我已经在 pyspark 中为此工作了一段时间,但我被卡住了。我正在尝试获取列 numbers 对应的 window 的中位数。我需要在不使用其他库(例如 numpy 等)的情况下执行此操作

到目前为止(如下图所示),我已按 id 列将数据集分组为 windows。这由列 row_numbers 描述,它向您展示了每个 window 的样子。此数据框示例中有三个 windows。

这是我想要的:

我希望每一行还包含 id 列的 window 的中位数,而不考虑它自己的行。我需要的中位数位置在我下面的函数中,称为 median_loc

示例:对于 row_number = 5,我需要找到其上方第 1 至 4 行的中位数(即不包括 row_number 5).因此,中位数(根据我的要求)是 id 列的平均值 window 其中 row_number = 1 和 row_number = 2 即

Date        id      numbers row_number  med_loc
2017-03-02  group 1   98        1       [1]
2017-04-01  group 1   50        2       [1]
2018-03-02  group 1   5         3       [1, 2]
2016-03-01  group 2   49        1       [1]
2016-12-22  group 2   81        2       [1]
2017-12-31  group 2   91        3       [1, 2]
2018-08-08  group 2   19        4       [2]
2018-09-25  group 2   52        5       [1, 2]
2017-01-01  group 3   75        1       [1]
2018-12-12  group 3   17        2       [1]

我用来获取最后一列med_loc的代码如下

def median_loc(sz):
    if sz == 1 or sz == 0:
        kth = [1]
        return kth
    elif sz % 2 == 0 and sz > 1:
        szh = sz // 2
        kth = [szh - 1, szh] if szh != 1 else [1, 2]
        return kth
    elif sz % 2 != 0 and sz > 1:
        kth = [(sz + 1) // 2]
        return kth


sqlContext.udf.register("median_location", median_loc)

median_loc = F.udf(median_loc)

df = df.withColumn("med_loc", median_loc(df.row_number)-1)

注意:为了便于理解,我只是将它们看起来像一个列表。它只是为了显示中位数在各自 window 中的位置。这只是为了让在 Stack Overflow

上阅读本文的人们更容易理解

我想要的输出如下:

Date        id      numbers row_number  med_loc     median
2017-03-02  group 1   98        1       [1]           98
2017-04-01  group 1   50        2       [1]           98
2018-03-02  group 1   5         3       [1, 2]        74
2016-03-01  group 2   49        1       [1]           49
2016-12-22  group 2   81        2       [1]           49
2017-12-31  group 2   91        3       [1, 2]        65
2018-08-08  group 2   19        4       [2]           81
2018-09-25  group 2   52        5       [1, 2]        65
2017-01-01  group 3   75        1       [1]           75
2018-12-12  group 3   17        2       [1]           75

基本上,到目前为止获得中位数的方法是这样的:

  1. 如果 med_loc 是一位数字(即,如果列表只有一位数字,例如 [1] 或 [3] 等),则中位数 = df.numbers 其中 df.row_number = df.med_loc

  2. 如果 med_loc 是两位数(即如果列表有两位数,例如 [1,2] 或 [2, 3] 等)则 median = average(df.numbers) 其中 df.row_number 在 df.med_loc

我怎么强调都不为过 使用其他库(如 numpy 等)来获取输出对我来说有多重要。我查看了其他使用 np.median 的解决方案,它们有效,但是,这不是我目前的要求。

很抱歉,如果这个解释过于冗长并且使它复杂化。我已经看了好几天了,似乎无法弄清楚。我也尝试使用 percent_rank 函数,但我无法弄清楚,因为并非所有 windows 都包含 0.5 个百分位数。

我们将不胜感激。

假设您从以下 DataFrame 开始,df:

+----------+-------+-------+
|      Date|     id|numbers|
+----------+-------+-------+
|2017-03-02|group 1|     98|
|2017-04-01|group 1|     50|
|2018-03-02|group 1|      5|
|2016-03-01|group 2|     49|
|2016-12-22|group 2|     81|
|2017-12-31|group 2|     91|
|2018-08-08|group 2|     19|
|2018-09-25|group 2|     52|
|2017-01-01|group 3|     75|
|2018-12-12|group 3|     17|
+----------+-------+-------+

订购数据框

首先像您在示例中所做的那样添加 row_number 并将输出分配给新的 DataFrame df2:

import pyspark.sql.functions as f
from pyspark.sql import Window

df2 = df.select(
    "*", f.row_number().over(Window.partitionBy("id").orderBy("Date")).alias("row_number")
)
df2.show()
#+----------+-------+-------+----------+
#|      Date|     id|numbers|row_number|
#+----------+-------+-------+----------+
#|2017-03-02|group 1|     98|         1|
#|2017-04-01|group 1|     50|         2|
#|2018-03-02|group 1|      5|         3|
#|2016-03-01|group 2|     49|         1|
#|2016-12-22|group 2|     81|         2|
#|2017-12-31|group 2|     91|         3|
#|2018-08-08|group 2|     19|         4|
#|2018-09-25|group 2|     52|         5|
#|2017-01-01|group 3|     75|         1|
#|2018-12-12|group 3|     17|         2|
#+----------+-------+-------+----------+

收集中位数值

现在您可以在 id 列上将 df2 连接到自身,条件是左侧的 row number1 或大于右侧的row_number。然后按左侧 DataFrame 的 ("id", "Date", "row_number") 分组,并将右侧 DataFrame 中的 numbers 收集到列表中。

对于row_number等于1的情况,我们只想保留这个集合列表的第一个元素。否则保留所有数字,但对它们进行排序,因为我们需要对它们进行排序以计算中位数。

调用这个中间数据帧df3:

df3 = df2.alias("l").join(df2.alias("r"), on="id", how="left")\
    .where("l.row_number = 1 OR (r.row_number < l.row_number)")\
    .groupBy("l.id", "l.Date", "l.row_number")\
    .agg(f.collect_list("r.numbers").alias("numbers"))\
    .select(
        "id",
        "Date",
        "row_number",
        f.when(
            f.col("row_number") == 1,
            f.array([f.col("numbers").getItem(0)])
        ).otherwise(f.sort_array("numbers")).alias("numbers")
    )
df3.show()
#+-------+----------+----------+----------------+
#|     id|      Date|row_number|         numbers|
#+-------+----------+----------+----------------+
#|group 1|2017-03-02|         1|            [98]|
#|group 1|2017-04-01|         2|            [98]|
#|group 1|2018-03-02|         3|        [50, 98]|
#|group 2|2016-03-01|         1|            [49]|
#|group 2|2016-12-22|         2|            [49]|
#|group 2|2017-12-31|         3|        [49, 81]|
#|group 2|2018-08-08|         4|    [49, 81, 91]|
#|group 2|2018-09-25|         5|[19, 49, 81, 91]|
#|group 3|2017-01-01|         1|            [75]|
#|group 3|2018-12-12|         2|            [75]|
#+-------+----------+----------+----------------+

请注意 df3numbers 列有一个列表,其中列出了我们要为其找到中位数的适当值。

计算中位数

由于您的 Spark 版本高于 2.1,您可以使用 pyspark.sql.functions.posexplode() 计算此值列表的中位数。对于较低版本的 spark,您需要使用 udf.

首先在 df3 中创建 2 个辅助列:

  • isEven:一个布尔值,用于指示 numbers 数组是否具有偶数个元素
  • middle:数组中间的索引,也就是长度/2的底

创建这些列后,使用 posexplode() 分解数组,这将 return 两个新列:poscol。然后我们过滤掉生成的 DataFrame 以仅保留我们需要计算中位数的位置。

持仓逻辑如下:

  • 如果isEvenFalse,我们只保留中间位置
  • 如果isEvenTrue,我们保留中间位置和中间位置 - 1.

最后按 idDate 分组并对剩余的 numbers.

进行平均
df3.select(
    "*",
    f.when(
        (f.size("numbers") % 2) == 0,
        f.lit(True)
    ).otherwise(f.lit(False)).alias("isEven"),
    f.floor(f.size("numbers")/2).alias("middle")
).select(
        "id", 
        "Date",
        "middle",
        f.posexplode("numbers")
).where(
    "(isEven=False AND middle=pos) OR (isEven=True AND pos BETWEEN middle-1 AND middle)"
).groupby("id", "Date").agg(f.avg("col").alias("median")).show()
#+-------+----------+------+
#|     id|      Date|median|
#+-------+----------+------+
#|group 1|2017-03-02|  98.0|
#|group 1|2017-04-01|  98.0|
#|group 1|2018-03-02|  74.0|
#|group 2|2016-03-01|  49.0|
#|group 2|2016-12-22|  49.0|
#|group 2|2017-12-31|  65.0|
#|group 2|2018-08-08|  81.0|
#|group 2|2018-09-25|  65.0|
#|group 3|2017-01-01|  75.0|
#|group 3|2018-12-12|  75.0|
#+-------+----------+------+