如何创建 spark udf 以将 float 插值到 INT 以及如何编写比我所做的更好的逻辑
How can I create spark udf for interpolation of float to INT and how can I write better logic than I have done
下面是我的 Spark Dataframe 我想做插值并为此编写一个 Spark UDF 我不确定如何编写更好的逻辑并从上面创建一个 UDF
这是为了转换 Position_float 并将其插入整数以将位置转换为适当的整数值
def dirty_fill(df, id_col, y_cols):
from pyspark.sql import types as T
df = df.withColumn('position_plus', (df.position_float + 0.5).cast(T.IntegerType()))
df = df.withColumn('position_minus', (df.position_float - 0.5).cast(T.IntegerType()))
df = df.withColumn('position', df.position_float.cast(T.IntegerType()))
df1 = df.select([id_col, 'position_plus'] + y_cols).withColumnRenamed('position_plus', 'position')
df2 = df.select([id_col, 'position_minus'] + y_cols).withColumnRenamed('position_minus', 'position')
df3 = df.select([id_col, 'position'] + y_cols)
df123 = df1.union(df2).union(df3).sort([id_col, 'position']).dropDuplicates([id_col, 'position'])
return df123
y_cols = ['entry_temperature']
finish_mill_entry_filled = dirty_fill(finish_mill_entry, 'finish_mill_id', y_cols)
这是我的数据框样本
| Finishing_mill_id | Sample | Position_float | Entry_Temp |
|--------------------|---------|----------------|------------|
| 2015418529 | 1 | 0.000000 | 1986.0 |
| 2015418529 | 2 | 2.192982 | 1997.0 |
| 2015418529 | 3 | 4.385965 | 2003.0 |
| 2018171498 | 445 | 495.535714 | 1643.0 |
| 2018171498 | 446 | 496.651786 | 1734.0 |
| 2018171498 | 447 | 497.767857 | 1748.0 |
| 2018171498 | 448 | 498.883929 | 1755.0 |
我需要将浮点数插入整数
我要的是
| Finishing_mill_id | Sample | Position_float | Entry_Temp |
|--------------------|---------|----------------|------------|
| 2015418529 | 1 | 0 | 1986.0 |
| 2015418529 | 2 | 1 | 1986 |
| 2015418529 | 3 | 2 | 1997.0 |
| 2015418529 | 4 | 3 | 1997 |
| 2015418529 | 5 | 4 | 2003.0 |
| 2018171498 | 445 | 496 | 1643.0 |
| 2018171498 | 446 | 497 | 1734.0 |
| 2018171498 | 447 | 498 | 1748.0 |
| 2018171498 | 448 | 499 | 1755.0 |
我需要一个 spark user_defined 函数来执行此操作,并且不应遗漏任何数据点,因为我 Position_float 在 0-500 范围内,我还需要注意有每一点都不会遗漏任何一点。需要以适当的方式修改我的插值逻辑
为了说清楚我有立场
0.000
2.19,但我没有 datapaint,但我需要什么,我需要有 1.00 的位置。我需要位置 1.00 的值,即使数据不存在那种线性插值。我希望它有所帮助
只需使用 round
并键入转换为 IntegerType
from pyspark.sql import functions as F
from pyspark.sql import types as T
df = df.withColumn('Position_float', F.round(F.col('Position_float')).cast(T.IntegerType()))
1. Window 函数
您可以使用 window 函数来填补空白并插入值。
让我们从示例数据框开始:
import pyspark.sql.functions as psf
import pyspark.sql.types as pst
from pyspark.sql import Window
import numpy as np
df = spark.createDataFrame(
[[float(t)/10., float(v)] for t, v in zip(np.random.randint(0, 1000, 20), np.random.randint(100, 200, 20))],
schema=pst.StructType([pst.StructField(c, pst.FloatType()) for c in ['position', 'value']])) \
.withColumn('position_round', psf.round('position'))
+--------+-----+--------------+
|position|value|position_round|
+--------+-----+--------------+
| 68.5|121.0| 69.0|
| 76.3|126.0| 76.0|
| 88.3|150.0| 88.0|
| 59.0|197.0| 59.0|
| 20.7|119.0| 21.0|
| 0.1|167.0| 0.0|
| 20.1|177.0| 20.0|
| 81.9|199.0| 82.0|
| 63.6|163.0| 64.0|
| 32.4|115.0| 32.0|
| 43.6|130.0| 44.0|
| 11.9|175.0| 12.0|
| 68.2|176.0| 68.0|
| 28.9|184.0| 29.0|
| 46.3|199.0| 46.0|
| 9.7|155.0| 10.0|
| 57.8|163.0| 58.0|
| 83.6|173.0| 84.0|
| 16.2|169.0| 16.0|
| 87.1|127.0| 87.0|
+--------+-----+--------------+
为了填补空白,我们将创建一个整数范围:
start, end = list(df.agg(psf.min('position_round'), psf.max('position_round')).collect()[0])
pos_df = spark.range(start=start, end=end, step=1) \
.withColumnRenamed('id', 'position_round')
现在我们可以加入两个数据框:
w1 = Window.orderBy('position_round')
w2 = Window.partitionBy('group').orderBy('position_round')
df_resample = df \
.select(
'*',
psf.lead('position_round', 1).over(w1).alias('next_position'),
psf.lead('value', 1).over(w1).alias('next_value')) \
.join(pos_df, on='position_round', how='right') \
.withColumn('group', psf.sum((~psf.isnull('position')).cast('int')).over(w1)) \
.select(
'*',
(psf.row_number().over(w2) - 1).alias('i'),
psf.first(psf.col('next_position') - psf.col('position_round')).over(w2).alias('dx'),
psf.first('value').over(w2).alias('value0'),
psf.first(psf.col('next_value') - psf.col('value')).over(w2).alias('dy')) \
.withColumn(
'value_round',
psf.when((psf.col('dx') > 0) | psf.isnull('next_value'), psf.col('value0') + psf.col('i') * psf.col('dy') / psf.col('dx')) \
.otherwise(psf.col('value')))
- 第一个 window 函数是存储
next_value
和 next_position
以便稍后能够计算我们的 dx
和 dy
- 然后我们需要用不同的
group
id 来识别每个间隙,以便我们可以为每个不同的线性段插入值
- 最后但同样重要的是,我们汇集了我们需要的所有元素:
- 间隙长度:
dx
- 数值增量:
dy
- 间隙中的当前行索引
i
我们现在可以计算 value_round
,value
在位置 position_round
的插值
+--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
|position_round|position|value|next_position|next_value|group| i| dx|value0| dy|value_round|
+--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
| 0| 0.1|167.0| 10.0| 155.0| 1| 0|10.0| 167.0|-12.0| 167.0|
| 1| null| null| null| null| 1| 1|10.0| 167.0|-12.0| 165.8|
| 2| null| null| null| null| 1| 2|10.0| 167.0|-12.0| 164.6|
| 3| null| null| null| null| 1| 3|10.0| 167.0|-12.0| 163.4|
| 4| null| null| null| null| 1| 4|10.0| 167.0|-12.0| 162.2|
| 5| null| null| null| null| 1| 5|10.0| 167.0|-12.0| 161.0|
| 6| null| null| null| null| 1| 6|10.0| 167.0|-12.0| 159.8|
| 7| null| null| null| null| 1| 7|10.0| 167.0|-12.0| 158.6|
| 8| null| null| null| null| 1| 8|10.0| 167.0|-12.0| 157.4|
| 9| null| null| null| null| 1| 9|10.0| 167.0|-12.0| 156.2|
| 10| 9.7|155.0| 12.0| 175.0| 2| 0| 2.0| 155.0| 20.0| 155.0|
| 11| null| null| null| null| 2| 1| 2.0| 155.0| 20.0| 165.0|
| 12| 11.9|175.0| 16.0| 169.0| 3| 0| 4.0| 175.0| -6.0| 175.0|
| 13| null| null| null| null| 3| 1| 4.0| 175.0| -6.0| 173.5|
| 14| null| null| null| null| 3| 2| 4.0| 175.0| -6.0| 172.0|
| 15| null| null| null| null| 3| 3| 4.0| 175.0| -6.0| 170.5|
| 16| 16.2|169.0| 20.0| 177.0| 4| 0| 4.0| 169.0| 8.0| 169.0|
| 17| null| null| null| null| 4| 1| 4.0| 169.0| 8.0| 171.0|
| 18| null| null| null| null| 4| 2| 4.0| 169.0| 8.0| 173.0|
| 19| null| null| null| null| 4| 3| 4.0| 169.0| 8.0| 175.0|
+--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
2。 UDF
如果你不想使用 window 函数你可以写一个 UDF
在 python
中做插值然后 return 一个数组(位置, 值)元组:
def interpolate(pos, next_pos, value, next_value):
if pos == next_pos or next_value is None:
return [(pos, value)]
return [[pos + i, value + i * (next_value - value) / (next_pos - pos)] for i in range(int(next_pos - pos))]
interpolate_udf = psf.udf(interpolate, pst.ArrayType(pst.StructType([pst.StructField(c, pst.FloatType()) for c in ['position_round', 'value_round']])))
请注意,元组是 StructType
类型,以便更容易 "flatten" 将元组放入列中。
w1 = Window.orderBy('position_round')
df_udf = df \
.select(
'*',
psf.lead('position_round', 1).over(w1).alias('next_position'),
psf.lead('value', 1).over(w1).alias('next_value')) \
.withColumn('tmp', psf.explode(interpolate_udf('position_round', 'next_position', 'value', 'next_value'))) \
.select('*', 'tmp.*').drop('tmp')
这是我们得到的:
+--------+-----+--------------+-------------+----------+--------------+----------+
|position|value|position_round|next_position|next_value|position_round|value_round|
+--------+-----+--------------+-------------+----------+--------------+----------+
| 0.1|167.0| 0.0| 10.0| 155.0| 0.0| 167.0|
| 0.1|167.0| 0.0| 10.0| 155.0| 1.0| 165.8|
| 0.1|167.0| 0.0| 10.0| 155.0| 2.0| 164.6|
| 0.1|167.0| 0.0| 10.0| 155.0| 3.0| 163.4|
| 0.1|167.0| 0.0| 10.0| 155.0| 4.0| 162.2|
| 0.1|167.0| 0.0| 10.0| 155.0| 5.0| 161.0|
| 0.1|167.0| 0.0| 10.0| 155.0| 6.0| 159.8|
| 0.1|167.0| 0.0| 10.0| 155.0| 7.0| 158.6|
| 0.1|167.0| 0.0| 10.0| 155.0| 8.0| 157.4|
| 0.1|167.0| 0.0| 10.0| 155.0| 9.0| 156.2|
| 9.7|155.0| 10.0| 12.0| 175.0| 10.0| 155.0|
| 9.7|155.0| 10.0| 12.0| 175.0| 11.0| 165.0|
| 11.9|175.0| 12.0| 16.0| 169.0| 12.0| 175.0|
| 11.9|175.0| 12.0| 16.0| 169.0| 13.0| 173.5|
| 11.9|175.0| 12.0| 16.0| 169.0| 14.0| 172.0|
| 11.9|175.0| 12.0| 16.0| 169.0| 15.0| 170.5|
| 16.2|169.0| 16.0| 20.0| 177.0| 16.0| 169.0|
| 16.2|169.0| 16.0| 20.0| 177.0| 17.0| 171.0|
| 16.2|169.0| 16.0| 20.0| 177.0| 18.0| 173.0|
| 16.2|169.0| 16.0| 20.0| 177.0| 19.0| 175.0|
+--------+-----+--------------+-------------+----------+--------------+----------+
下面是我的 Spark Dataframe 我想做插值并为此编写一个 Spark UDF 我不确定如何编写更好的逻辑并从上面创建一个 UDF
这是为了转换 Position_float 并将其插入整数以将位置转换为适当的整数值
def dirty_fill(df, id_col, y_cols):
from pyspark.sql import types as T
df = df.withColumn('position_plus', (df.position_float + 0.5).cast(T.IntegerType()))
df = df.withColumn('position_minus', (df.position_float - 0.5).cast(T.IntegerType()))
df = df.withColumn('position', df.position_float.cast(T.IntegerType()))
df1 = df.select([id_col, 'position_plus'] + y_cols).withColumnRenamed('position_plus', 'position')
df2 = df.select([id_col, 'position_minus'] + y_cols).withColumnRenamed('position_minus', 'position')
df3 = df.select([id_col, 'position'] + y_cols)
df123 = df1.union(df2).union(df3).sort([id_col, 'position']).dropDuplicates([id_col, 'position'])
return df123
y_cols = ['entry_temperature']
finish_mill_entry_filled = dirty_fill(finish_mill_entry, 'finish_mill_id', y_cols)
这是我的数据框样本
| Finishing_mill_id | Sample | Position_float | Entry_Temp |
|--------------------|---------|----------------|------------|
| 2015418529 | 1 | 0.000000 | 1986.0 |
| 2015418529 | 2 | 2.192982 | 1997.0 |
| 2015418529 | 3 | 4.385965 | 2003.0 |
| 2018171498 | 445 | 495.535714 | 1643.0 |
| 2018171498 | 446 | 496.651786 | 1734.0 |
| 2018171498 | 447 | 497.767857 | 1748.0 |
| 2018171498 | 448 | 498.883929 | 1755.0 |
我需要将浮点数插入整数
我要的是
| Finishing_mill_id | Sample | Position_float | Entry_Temp |
|--------------------|---------|----------------|------------|
| 2015418529 | 1 | 0 | 1986.0 |
| 2015418529 | 2 | 1 | 1986 |
| 2015418529 | 3 | 2 | 1997.0 |
| 2015418529 | 4 | 3 | 1997 |
| 2015418529 | 5 | 4 | 2003.0 |
| 2018171498 | 445 | 496 | 1643.0 |
| 2018171498 | 446 | 497 | 1734.0 |
| 2018171498 | 447 | 498 | 1748.0 |
| 2018171498 | 448 | 499 | 1755.0 |
我需要一个 spark user_defined 函数来执行此操作,并且不应遗漏任何数据点,因为我 Position_float 在 0-500 范围内,我还需要注意有每一点都不会遗漏任何一点。需要以适当的方式修改我的插值逻辑
为了说清楚我有立场 0.000 2.19,但我没有 datapaint,但我需要什么,我需要有 1.00 的位置。我需要位置 1.00 的值,即使数据不存在那种线性插值。我希望它有所帮助
只需使用 round
并键入转换为 IntegerType
from pyspark.sql import functions as F
from pyspark.sql import types as T
df = df.withColumn('Position_float', F.round(F.col('Position_float')).cast(T.IntegerType()))
1. Window 函数
您可以使用 window 函数来填补空白并插入值。
让我们从示例数据框开始:
import pyspark.sql.functions as psf
import pyspark.sql.types as pst
from pyspark.sql import Window
import numpy as np
df = spark.createDataFrame(
[[float(t)/10., float(v)] for t, v in zip(np.random.randint(0, 1000, 20), np.random.randint(100, 200, 20))],
schema=pst.StructType([pst.StructField(c, pst.FloatType()) for c in ['position', 'value']])) \
.withColumn('position_round', psf.round('position'))
+--------+-----+--------------+
|position|value|position_round|
+--------+-----+--------------+
| 68.5|121.0| 69.0|
| 76.3|126.0| 76.0|
| 88.3|150.0| 88.0|
| 59.0|197.0| 59.0|
| 20.7|119.0| 21.0|
| 0.1|167.0| 0.0|
| 20.1|177.0| 20.0|
| 81.9|199.0| 82.0|
| 63.6|163.0| 64.0|
| 32.4|115.0| 32.0|
| 43.6|130.0| 44.0|
| 11.9|175.0| 12.0|
| 68.2|176.0| 68.0|
| 28.9|184.0| 29.0|
| 46.3|199.0| 46.0|
| 9.7|155.0| 10.0|
| 57.8|163.0| 58.0|
| 83.6|173.0| 84.0|
| 16.2|169.0| 16.0|
| 87.1|127.0| 87.0|
+--------+-----+--------------+
为了填补空白,我们将创建一个整数范围:
start, end = list(df.agg(psf.min('position_round'), psf.max('position_round')).collect()[0])
pos_df = spark.range(start=start, end=end, step=1) \
.withColumnRenamed('id', 'position_round')
现在我们可以加入两个数据框:
w1 = Window.orderBy('position_round')
w2 = Window.partitionBy('group').orderBy('position_round')
df_resample = df \
.select(
'*',
psf.lead('position_round', 1).over(w1).alias('next_position'),
psf.lead('value', 1).over(w1).alias('next_value')) \
.join(pos_df, on='position_round', how='right') \
.withColumn('group', psf.sum((~psf.isnull('position')).cast('int')).over(w1)) \
.select(
'*',
(psf.row_number().over(w2) - 1).alias('i'),
psf.first(psf.col('next_position') - psf.col('position_round')).over(w2).alias('dx'),
psf.first('value').over(w2).alias('value0'),
psf.first(psf.col('next_value') - psf.col('value')).over(w2).alias('dy')) \
.withColumn(
'value_round',
psf.when((psf.col('dx') > 0) | psf.isnull('next_value'), psf.col('value0') + psf.col('i') * psf.col('dy') / psf.col('dx')) \
.otherwise(psf.col('value')))
- 第一个 window 函数是存储
next_value
和next_position
以便稍后能够计算我们的dx
和dy
- 然后我们需要用不同的
group
id 来识别每个间隙,以便我们可以为每个不同的线性段插入值 - 最后但同样重要的是,我们汇集了我们需要的所有元素:
- 间隙长度:
dx
- 数值增量:
dy
- 间隙中的当前行索引
i
- 间隙长度:
我们现在可以计算 value_round
,value
在位置 position_round
+--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
|position_round|position|value|next_position|next_value|group| i| dx|value0| dy|value_round|
+--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
| 0| 0.1|167.0| 10.0| 155.0| 1| 0|10.0| 167.0|-12.0| 167.0|
| 1| null| null| null| null| 1| 1|10.0| 167.0|-12.0| 165.8|
| 2| null| null| null| null| 1| 2|10.0| 167.0|-12.0| 164.6|
| 3| null| null| null| null| 1| 3|10.0| 167.0|-12.0| 163.4|
| 4| null| null| null| null| 1| 4|10.0| 167.0|-12.0| 162.2|
| 5| null| null| null| null| 1| 5|10.0| 167.0|-12.0| 161.0|
| 6| null| null| null| null| 1| 6|10.0| 167.0|-12.0| 159.8|
| 7| null| null| null| null| 1| 7|10.0| 167.0|-12.0| 158.6|
| 8| null| null| null| null| 1| 8|10.0| 167.0|-12.0| 157.4|
| 9| null| null| null| null| 1| 9|10.0| 167.0|-12.0| 156.2|
| 10| 9.7|155.0| 12.0| 175.0| 2| 0| 2.0| 155.0| 20.0| 155.0|
| 11| null| null| null| null| 2| 1| 2.0| 155.0| 20.0| 165.0|
| 12| 11.9|175.0| 16.0| 169.0| 3| 0| 4.0| 175.0| -6.0| 175.0|
| 13| null| null| null| null| 3| 1| 4.0| 175.0| -6.0| 173.5|
| 14| null| null| null| null| 3| 2| 4.0| 175.0| -6.0| 172.0|
| 15| null| null| null| null| 3| 3| 4.0| 175.0| -6.0| 170.5|
| 16| 16.2|169.0| 20.0| 177.0| 4| 0| 4.0| 169.0| 8.0| 169.0|
| 17| null| null| null| null| 4| 1| 4.0| 169.0| 8.0| 171.0|
| 18| null| null| null| null| 4| 2| 4.0| 169.0| 8.0| 173.0|
| 19| null| null| null| null| 4| 3| 4.0| 169.0| 8.0| 175.0|
+--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
2。 UDF
如果你不想使用 window 函数你可以写一个 UDF
在 python
中做插值然后 return 一个数组(位置, 值)元组:
def interpolate(pos, next_pos, value, next_value):
if pos == next_pos or next_value is None:
return [(pos, value)]
return [[pos + i, value + i * (next_value - value) / (next_pos - pos)] for i in range(int(next_pos - pos))]
interpolate_udf = psf.udf(interpolate, pst.ArrayType(pst.StructType([pst.StructField(c, pst.FloatType()) for c in ['position_round', 'value_round']])))
请注意,元组是 StructType
类型,以便更容易 "flatten" 将元组放入列中。
w1 = Window.orderBy('position_round')
df_udf = df \
.select(
'*',
psf.lead('position_round', 1).over(w1).alias('next_position'),
psf.lead('value', 1).over(w1).alias('next_value')) \
.withColumn('tmp', psf.explode(interpolate_udf('position_round', 'next_position', 'value', 'next_value'))) \
.select('*', 'tmp.*').drop('tmp')
这是我们得到的:
+--------+-----+--------------+-------------+----------+--------------+----------+
|position|value|position_round|next_position|next_value|position_round|value_round|
+--------+-----+--------------+-------------+----------+--------------+----------+
| 0.1|167.0| 0.0| 10.0| 155.0| 0.0| 167.0|
| 0.1|167.0| 0.0| 10.0| 155.0| 1.0| 165.8|
| 0.1|167.0| 0.0| 10.0| 155.0| 2.0| 164.6|
| 0.1|167.0| 0.0| 10.0| 155.0| 3.0| 163.4|
| 0.1|167.0| 0.0| 10.0| 155.0| 4.0| 162.2|
| 0.1|167.0| 0.0| 10.0| 155.0| 5.0| 161.0|
| 0.1|167.0| 0.0| 10.0| 155.0| 6.0| 159.8|
| 0.1|167.0| 0.0| 10.0| 155.0| 7.0| 158.6|
| 0.1|167.0| 0.0| 10.0| 155.0| 8.0| 157.4|
| 0.1|167.0| 0.0| 10.0| 155.0| 9.0| 156.2|
| 9.7|155.0| 10.0| 12.0| 175.0| 10.0| 155.0|
| 9.7|155.0| 10.0| 12.0| 175.0| 11.0| 165.0|
| 11.9|175.0| 12.0| 16.0| 169.0| 12.0| 175.0|
| 11.9|175.0| 12.0| 16.0| 169.0| 13.0| 173.5|
| 11.9|175.0| 12.0| 16.0| 169.0| 14.0| 172.0|
| 11.9|175.0| 12.0| 16.0| 169.0| 15.0| 170.5|
| 16.2|169.0| 16.0| 20.0| 177.0| 16.0| 169.0|
| 16.2|169.0| 16.0| 20.0| 177.0| 17.0| 171.0|
| 16.2|169.0| 16.0| 20.0| 177.0| 18.0| 173.0|
| 16.2|169.0| 16.0| 20.0| 177.0| 19.0| 175.0|
+--------+-----+--------------+-------------+----------+--------------+----------+