Pyspark - 将具有最大值的列转换为单独的 1 和 0 条目

Pyspark - Transform columns with maximum values into separate 1 and 0 entries

我在 pandas 中有针对此问题的工作版本,但我无法将其转换为 pyspark。

我的 input DataFrame 如下所示:

test_df = pd.DataFrame({
    'id': [1],
    'cat_1': [2],
    'cat_2': [2],
    'cat_3': [1]
})
test_df_spark = spark.createDataFrame(test_df)
test_df_spark.show()

+---+-----+-----+-----+
| id|cat_1|cat_2|cat_3|
+---+-----+-----+-----+
|  1|    2|    2|    1| <- non-maximum
+---+-----+-----+-----+
         ^     ^
         |     |
     maximum maximum    

我愿意:

  1. 获取 cat_1、cat_2、cat_3 中具有最大值的列(1 或更多)。在示例中,这些将是 cat_1 和 cat_2.
  2. 这些列的值应为 1。其余非最大列将设置为 0。
  3. 具有 1 个值的列应分成单独的行。

结果 DataFrame 应该如下所示:

+---+-----+-----+-----+
| id|cat_1|cat_2|cat_3|
+---+-----+-----+-----+
|  1|    1|    0|    0|
|  1|    0|    1|    0|
+---+-----+-----+-----+

目前,我能弄清楚的最多的是如何根据它们的值(无论是否为最大值)将列设置为 1 或 0,但是我仍然不知道如何生成其他条目:

columns = ['cat_1', 'cat_2', 'cat_3']
(
    test_df_spark
    .withColumn(
        'max_value',
        F.greatest(
            *columns
        )
    )
    .select(
        'id',
        *[F.when(F.col(c) == F.col('max_value'), F.lit(1)).otherwise(F.lit(0)).alias(c) for c in columns]
    )
    .show()
)

+---+-----+-----+-----+
| id|cat_1|cat_2|cat_3|
+---+-----+-----+-----+
|  1|    1|    1|    0|
+---+-----+-----+-----+

提前致谢!

假设你现在的结果是df1:

columns = ['cat_1', 'cat_2', 'cat_3']
df1 = (
    test_df_spark
    .withColumn(
        'max_value',
        F.greatest(
            *columns
        )
    )
    .select(
        'id',
        *[F.when(F.col(c) == F.col('max_value'), F.lit(1)).otherwise(F.lit(0)).alias(c) for c in columns]
    )
)

您可以通过创建结构数组来操纵 df1 以获得您想要的结果,inline 它:

df2 = df1.select(
    'id', 
    F.array(*[
        F.when(
            F.col(c1) == 1, 
            F.struct(*[
                F.lit(1).alias(c2) if i1 == i2 else F.lit(0).alias(c2) 
                for i2, c2 in enumerate(columns)
            ])
        ) 
        for i1, c1 in enumerate(columns)
    ]).alias('cat')
).selectExpr(
    'id', 
    'inline(filter(cat, x -> x is not null))'
)

df2.show()
+---+-----+-----+-----+
| id|cat_1|cat_2|cat_3|
+---+-----+-----+-----+
|  1|    1|    0|    0|
|  1|    0|    1|    0|
+---+-----+-----+-----+