在 PySpark 中的 DataFrame 的某些行上应用操作(其结果取决于整个 DataFrame 中的信息)

Applying an operation (whose result depends on the information in the entire DataFrame) on certain rows of the DataFrame in PySpark

我目前正在努力实现一种 Spark-y 方式来使用 PySpark 进行操作。我有一个具有以下结构的大型 DataFrame(约 500,000 行)

   ID    DATE    CHANGE    POOL
   -----------------------------
 1 ID1   DATE1   CHANGE1   POOL1
 2 ID2   DATE2   CHANGE2   POOL2
 3 ID3   DATE3   CHANGE3   POOL3
 4 ID4   DATE4   CHANGE4   POOL4
 ....

其中 ID 是唯一的,DATE 不一定彼此等距(并且可能重复),CHANGE 可以是任意数量的 float 类型(通常但不总是在 -500.0 到 +500.0 之间),并且 POOLs 可以是空列表 [] 或可变长度的非空列表,具体取决于在特定的行上。我们可以安全地假设 DataFrame 按 DATE.

排序

我想通过以下操作替换 POOL 列中的空列表。我只考虑 W 天的 window 内的事件(例如,W 是 180 天)。

  1. 对于 POOL 为空的行,我注意到该行的 DATE
  2. I select all 来自原始 DataFrame 的行,这些行落在 window 过去的 W 天之内在步骤 1 中提到。
  3. 从这个子集中,我收集了所有 CHANGE 的列表(可能带有 collect_list() 函数)。
  4. 我将此列表分配给与步骤 1 中提到的行相对应的 POOL 列。

我尝试(部分)使用 window 函数实现如下(我可以安全地删除带有 partitionBy() 的行)

POOL_DAYS = 180
days = lambda i: i * 86400
window_glb = Window\
                 .partitionBy()\
                 .orderBy(col('DATE').cast('long'))\
                 .rangeBetween(-days(POOL_DAYS), 0)
df = df.withColumn('POOL', collect_list('CHANGE').over(window_glb))

但它因这个错误而崩溃

Py4JJavaError: An error occurred while calling o1859.collectToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 203.1 failed 4 times, most recent failure: Lost task 0.3 in stage 203.1 (TID 6339, xxxxxxxxxx.com, executor 1):
java.lang.IllegalArgumentException: Size exceeds Integer.MAX_VALUE

对我来说,这看起来像是 运行 内存不足的问题。为了对此进行测试,我在原始 DataFrame 的一个更小的子集上使用了相同的操作,并且它成功完成了。

请注意,我确实理解上面的实现对所有行重复操作而不考虑 POOL 列是否为空,但我首先想看看我是否可以让它工作或不在进行下一步之前;这种方法似乎不适用于大型 DataFrame。

蛮力(根本不是 Spark-y)方法是对 POOL 为空的行执行 for 循环并应用以下常规(不是 UDF)函数在原始 DataFrame

def get_pool(row, window_size):
    end_date = row.DATE
    start_date = end_date - datetime.timedelta(days = window_size)
    return df.where((col('DATE') < end_date) & (col('DATE') >= start_date))\
             .agg(collect_list('CHANGE'))\
             .rdd.flatMap(list).first()

但是,预计需要很长时间才能完成。

更新: 我也尝试过使用 UDF

def get_pool(date, window_size):
    end_date = date
    start_date = end_date - datetime.timedelta(days = window_size)
    return df.where((col('DATE') < end_date) & (col('DATE') >= start_date))\
             .agg(collect_list('CHANGE'))\
             .rdd.collect()
get_pool_udf = udf(get_pool, ArrayType(DoubleType()))

但由于函数的结果取决于整个(或大部分)DataFrame 的信息,因此它会因错误而停止

PicklingError: Could not serialize object: Py4JError: An error occurred while calling o1857.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist

少了第一种基于window函数的方法,暴力法,坏掉的UDF尝试,请问如何实现?我需要能够对从感兴趣的日期开始的特定日期范围内的 DataFrame 子集进行分组和处理;我如何使用 Spark 执行此操作?

我是 Spark(和 PySpark)的新手,很难思考这个问题。我真的很感激任何帮助。谢谢!

我在回答我自己的问题:我 select IDDATEPOOL 列为空,我执行了一个连接两个 DataFrame 中的一个,同时限制在感兴趣的日期范围内的那些行。按 ID 分组并将每个组中的 CHANGE 收集为列表,给出要分配给 POOL 列的所需列表,这些行的 POOL 是空列表:

renamed_df = df.select('DATE', 'CHANGE')\
               .withColumnRenamed('DATE', 'r_DATE')\
               .withColumnRenamed('CHANGE', 'r_CHANGE')

no_pool_df = df.where(size('POOL') == 0)

cross_df = no_pool_df\
            .join(renamed_df,\
                    (renamed_df.r_DATE >= date_sub(no_pool_df.DATE, POOL_DAYS)) &\
                    (renamed_df.r_DATE < no_pool_df.DATE)\
                , 'right')\
            .select('ID', 'r_CHANGE')\
            .groupBy('ID')\
            .agg(collect_list('r_CHANGE'))\
            .withColumnRenamed('ID', 'no_pool__ID')\
            .withColumnRenamed('collect_list(r_CHANGE)', 'no_pool__POOL')

df = df.join(cross_df, cross_df.no_pool__ID == df.ID, 'left_outer')\
       .drop('no_pool__ID')

现在,DataFrame df 有两列一起包含所需的信息,原始 POOL 和新创建的 no_pool__POOL。我使用以下 UDF 将这两个正确地混合到一列中,POOL_mix:

def mix_pools(history, no_history):
    if not history:
        return no_history
    else:
        return history
mix_pools_udf = sfn.udf(mix_pools, ArrayType(DoubleType()))

df = df.withColumn('POOL_mix', mix_pools_udf(col('POOL'), col('no_pool__POOL')))\
       .drop('POOL', 'no_pool__POOL')