Round based on column value in PySpark

I need to round summary_measure_value according to the value of reading_precision:

from pyspark.sql.functions import *
import pyspark.sql.functions as F
from pyspark.sql import *

df = spark.createDataFrame(
[(123, 2897402, 43.25, 2),
(124, 2897402, 49.25, 0),
(125, 2897402, 43.25, 2), 
(126, 2897402, 48.75, 0)]
, ['model_id','lab_test_id','summary_measure_value','reading_precision'])


partition_by_reading = [
    "model_id",
    "lab_test_id"
]
df.withColumn(
        "reading_value",
        round(avg("summary_measure_value").over(
                    Window.partitionBy(partition_by_reading))
                ,col("reading_precision"))).show()

I get TypeError: 'Column' object is not callable

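The PySpark round function expects a constant value for scale/precision, not a Column. You may have better luck creating a custom UDF that applies the rounding logic.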

Here is an example from a test I ran on the sample you shared:

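import builtins
from pyspark.sql.types import DoubleType

# builtins.round is Python's round; the star import shadows the bare name
# with pyspark.sql.functions.round. DoubleType avoids the default string result.
udf_round = F.udf(lambda val, precision: builtins.round(val, precision), DoubleType())
df.withColumn(
        "reading_value",
        udf_round(F.avg("summary_measure_value").over(
                    Window.partitionBy(partition_by_reading))
                ,F.col("reading_precision"))).show()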

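Note: the round inside the UDF must be Python's built-in round function, not pyspark.sql.functions.round. After `from pyspark.sql.functions import *` the bare name round refers to the PySpark function, so use builtins.round in that case.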

Result:

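+--------+-----------+---------------------+-----------------+-------------+
|model_id|lab_test_id|summary_measure_value|reading_precision|reading_value|
+--------+-----------+---------------------+-----------------+-------------+
|     125|    2897402|                43.25|                2|        43.25|
|     124|    2897402|                49.25|                0|         49.0|
|     123|    2897402|                43.25|                2|        43.25|
|     126|    2897402|                48.75|                0|         49.0|
+--------+-----------+---------------------+-----------------+-------------+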
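The UDF approach relies on Python's built-in round, which rounds exact ties to the nearest even digit and works on the float's binary value. The PySpark documentation describes pyspark.sql.functions.round as using HALF_UP rounding (with bround as the HALF_EVEN variant), so the two can disagree on ties such as 0.5 or 2.5. If HALF_UP results are wanted inside the UDF, one option is Python's decimal module. This is a sketch under assumptions: `round_half_up` is a hypothetical helper, values are assumed finite, and going through str(value) rounds the float's shortest decimal representation, which is why 2.675 becomes 2.68 here while the built-in round gives 2.67.

```python
from decimal import Decimal, ROUND_HALF_UP


def round_half_up(value, precision):
    """Round value to `precision` decimals, rounding ties away from zero.

    Hypothetical helper (assumes finite values). It goes through str(value)
    so that 2.675 is treated as the decimal 2.675, not its binary approximation.
    """
    if value is None or precision is None:
        return None
    quantum = Decimal(1).scaleb(-int(precision))  # precision 2 -> Decimal('0.01')
    return float(Decimal(str(value)).quantize(quantum, rounding=ROUND_HALF_UP))


# Python's built-in round: round-half-to-even on the binary float value
print(round(0.5), round(2.5), round(2.675, 2))  # 0 2 2.67
# HALF_UP via decimal
print(round_half_up(0.5, 0), round_half_up(2.5, 0), round_half_up(2.675, 2))  # 1.0 3.0 2.68
```

In the UDF, builtins.round(val, precision) could then be swapped for round_half_up(val, precision) if HALF_UP is the desired convention.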

Edit 1:

If your data may contain bad values, the following UDF is more robust:

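import builtins
from pyspark.sql.types import DoubleType

@F.udf(returnType=DoubleType())  # udf's default return type is StringType
def udf_round(val, precision):
    try:
        # ensure that the value is a float and the precision an integer
        return float(builtins.round(float(val), int(precision)))
    except (TypeError, ValueError, OverflowError):
        return val  # return the unrounded value if conversion fails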
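Keeping the rounding logic in a plain Python function makes it easy to check the error handling locally before wrapping it for Spark. The sketch below mirrors the body of the robust UDF as a hypothetical helper `safe_round`, using builtins.round explicitly and catching the specific conversion errors rather than everything; the sample inputs are made up for illustration.

```python
import builtins


def safe_round(val, precision):
    """Plain-Python version of the UDF body: round, or fall back to val."""
    try:
        # coerce the value to float and the precision to int before rounding
        return float(builtins.round(float(val), int(precision)))
    except (TypeError, ValueError, OverflowError):
        # e.g. a null value, a missing precision, or a non-numeric string
        return val


print(safe_round(43.25, 2))       # 43.25
print(safe_round(49.25, 0))       # 49.0
print(safe_round("43.256", "2"))  # 43.26
print(safe_round(None, 2))        # None
print(safe_round(43.25, None))    # 43.25

# In Spark it could then be wrapped, e.g. F.udf(safe_round, DoubleType())
```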

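Out of curiosity, I also tried a pandas_udf that iterates over multiple Series, and it seemed faster than I expected. (The output below comes from slightly different sample data: row 124 has summary_measure_value 49.11 and reading_precision 1.)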

import builtins
from typing import Iterator, Tuple

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf


@pandas_udf('double')
def round_series(iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    # Series.round() takes a single integer, so round element by element
    # using each row's own precision (Python's built-in round)
    for values, precisions in iterator:
        yield pd.Series([builtins.round(float(v), int(p)) for v, p in zip(values, precisions)])


df1 = (df
       .withColumn("reading_value", F.avg("summary_measure_value").over(
           Window.partitionBy("model_id", "lab_test_id")))
       .withColumn("reading_value", round_series("reading_value", "reading_precision")))
df1.show()

+--------+-----------+---------------------+-----------------+-------------+
|model_id|lab_test_id|summary_measure_value|reading_precision|reading_value|
+--------+-----------+---------------------+-----------------+-------------+
|     123|    2897402|                43.25|                2|        43.25|
|     124|    2897402|                49.11|                1|         49.1|
|     125|    2897402|                43.25|                2|        43.25|
|     126|    2897402|                48.75|                0|         49.0|
+--------+-----------+---------------------+-----------------+-------------+
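pandas Series.round takes a single integer number of decimals, so a per-row precision has to be handled some other way inside a pandas UDF. One vectorized option, sketched below under the assumption that reading_precision has no nulls and only a few distinct values, is to round each group of rows sharing a precision with one Series.round call. `round_by_precision` is a hypothetical helper, not part of the original answer.

```python
import pandas as pd


def round_by_precision(values: pd.Series, precisions: pd.Series) -> pd.Series:
    """Round each value to the number of decimals given in the same row.

    Rows that share a precision are rounded together with one vectorized
    Series.round call. Assumes `precisions` has no nulls and shares the
    index of `values`.
    """
    result = pd.Series(index=values.index, dtype="float64")
    for p in precisions.unique():
        mask = precisions == p
        result.loc[mask] = values.loc[mask].round(int(p))
    return result


values = pd.Series([43.25, 49.25, 43.25, 48.75])
precisions = pd.Series([2, 0, 2, 0])
print(round_by_precision(values, precisions).tolist())  # [43.25, 49.0, 43.25, 49.0]
```

Inside a pandas UDF such as round_series, the loop body would then become `yield round_by_precision(values, precisions)`. Whether this beats an element-wise comprehension depends on the data; with many distinct precisions the per-group loop gains little.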