Pyspark 将间隔分成子间隔

PySpark: split an interval into sub-intervals

我有一个包含 3 列的数据框 "from"、"to"、"country",例如:

from to country
1    105 abc
500  1000 def

我想把每一行的 from–to 区间按大小 10 拆分成若干子区间,从而生成一个新的数据框。期望得到的数据框如下:

from to country
1    10 abc
11   20 abc
21   30 abc
31   40 abc
...
91   105 abc  (leftover values go into the last bucket of that range)
500  510 def

等等...

from pyspark.sql.functions import udf, col, explode, array, struct, length
from pyspark.sql.types import ArrayType, StructType, StructField, IntegerType

# Width of every generated sub-interval; the last sub-interval of each row
# absorbs whatever values are left over.
step_size = 10

# Sample input: each row is an inclusive [from, to] interval tagged with a
# country code.
values = [
    (1, 105, 'abc'),
    (500, 1000, 'def'),
]
df = sqlContext.createDataFrame(values, ['from', 'to', 'country'])
# UDFs that compute, for one [from, to] interval, the start and end value of
# every sub-interval of width step_size.
def _bucket_starts(start, end, step):
    """Return the first value of each sub-interval of the inclusive range [start, end].

    A bucket starting at ``i`` is kept only when a full bucket of ``step``
    values (``i`` .. ``i + step - 1``) fits at or before ``end``. Values left
    over at the end are absorbed by the last bucket instead of forming a
    short one.

    An interval narrower than one bucket (including ``start == end``) yields
    the single bucket start ``[start]``. An inverted interval
    (``start > end``) yields no buckets.
    """
    if start > end:
        return []
    starts = [i for i in range(start, end, step) if end - i >= step - 1]
    # A short interval would otherwise produce no buckets at all. The
    # matching "to" list would then be empty and its last element could not
    # be set to `end`. Keep the whole interval as one bucket instead.
    return starts or [start]


def make_list_from(start, end):
    """Return the start value of every sub-interval of [start, end].

    Sub-intervals are ``step_size`` wide; ``step_size`` is read from the
    module-level setting. Leftover values are absorbed by the last bucket.
    An interval shorter than one bucket becomes a single bucket.
    """
    return _bucket_starts(start, end, step_size)


make_list_from_udf = udf(make_list_from, ArrayType(IntegerType()))


def make_list_to(start, end):
    """Return the inclusive end value of every sub-interval of [start, end].

    The result is parallel to make_list_from(start, end): element k is the
    end of the bucket that starts at element k of that list. Each bucket
    ends at ``bucket_start + step_size - 1``, except the last. The last
    bucket is stretched to ``end`` so that leftover values are kept.

    An inverted interval (``start > end``) yields an empty list.
    """
    right_list = [i + step_size - 1 for i in _bucket_starts(start, end, step_size)]
    if right_list:
        right_list[-1] = end
    return right_list


make_list_to_udf = udf(make_list_to, ArrayType(IntegerType()))

# Replace each row's scalar [from, to] bounds with two parallel arrays.
# One array holds the sub-interval starts and the other the sub-interval ends.
df = (
    df
    .withColumn('my_list_from', make_list_from_udf(col('from'), col('to')))
    .withColumn('my_list_to', make_list_to_udf(col('from'), col('to')))
    .drop('from', 'to')
)
df.show()
+-------+--------------------+--------------------+
|country|        my_list_from|          my_list_to|
+-------+--------------------+--------------------+
|    abc|[1, 11, 21, 31, 4...|[10, 20, 30, 40, ...|
|    def|[500, 510, 520, 5...|[509, 519, 529, 5...|
+-------+--------------------+--------------------+

# Pair the parallel start/end arrays element-wise into (first, second)
# structs, then explode them so each sub-interval becomes its own row.
zip_ = udf(
    lambda starts, ends: list(zip(starts, ends)),
    ArrayType(
        StructType([
            # These field types must match the element types of the two
            # input arrays.
            StructField("first", IntegerType()),
            StructField("second", IntegerType()),
        ])
    ),
)
df = (
    df.withColumn("tmp", zip_("my_list_from", "my_list_to"))
    # NOTE(review): the UDF result is stored in a column and exploded in a
    # separate step. The original author reports it cannot be passed to
    # explode() directly; confirm this for the Spark version in use.
    .withColumn("tmp", explode("tmp"))
    .select(
        col("tmp.first").alias("from"),
        col("tmp.second").alias("to"),
        "country",
    )
)
df.show(100)
+----+----+-------+
|from|  to|country|
+----+----+-------+
|   1|  10|    abc|
|  11|  20|    abc|
|  21|  30|    abc|
|  31|  40|    abc|
|  41|  50|    abc|
|  51|  60|    abc|
|  61|  70|    abc|
|  71|  80|    abc|
|  81|  90|    abc|
|  91| 105|    abc|
| 500| 509|    def|
| 510| 519|    def|
| 520| 529|    def|
.
.
.
| 960| 969|    def|
| 970| 979|    def|
| 980| 989|    def|
| 990|1000|    def|
+----+----+-------+