Pivot and aggregate a PySpark Data Frame with alias
I have a PySpark DataFrame similar to this:
df = sc.parallelize([
("c1", "A", 3.4, 0.4, 3.5),
("c1", "B", 9.6, 0.0, 0.0),
("c1", "A", 2.8, 0.4, 0.3),
("c1", "B", 5.4, 0.2, 0.11),
("c2", "A", 0.0, 9.7, 0.3),
("c2", "B", 9.6, 8.6, 0.1),
("c2", "A", 7.3, 9.1, 7.0),
("c2", "B", 0.7, 6.4, 4.3)
]).toDF(["user_id", "type", "d1", 'd2', 'd3'])
df.show()
which gives:
+-------+----+---+---+----+
|user_id|type| d1| d2| d3|
+-------+----+---+---+----+
| c1| A|3.4|0.4| 3.5|
| c1| B|9.6|0.0| 0.0|
| c1| A|2.8|0.4| 0.3|
| c1| B|5.4|0.2|0.11|
| c2| A|0.0|9.7| 0.3|
| c2| B|9.6|8.6| 0.1|
| c2| A|7.3|9.1| 7.0|
| c2| B|0.7|6.4| 4.3|
+-------+----+---+---+----+
I pivoted it on the type column, aggregating the result with sum():
data_wide = df.groupBy('user_id')\
.pivot('type').sum()
data_wide.show()
which gives:
+-------+-----------------+------------------+-----------+------------------+-----------+------------------+
|user_id| A_sum(`d1`)| A_sum(`d2`)|A_sum(`d3`)| B_sum(`d1`)|B_sum(`d2`)| B_sum(`d3`)|
+-------+-----------------+------------------+-----------+------------------+-----------+------------------+
| c1|6.199999999999999| 0.8| 3.8| 15.0| 0.2| 0.11|
| c2| 7.3|18.799999999999997| 7.3|10.299999999999999| 15.0|4.3999999999999995|
+-------+-----------------+------------------+-----------+------------------+-----------+------------------+
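Now, the resulting column names contain the ` (backtick) character, which is a problem when, for example, feeding these new columns into a VectorAssembler, because it returns syntax error in attribute name. For this reason I need to rename the columns, but calling withColumnRenamed inside a loop or inside a reduce(lambda...) call takes a lot of time (my actual df has 11,520 columns).
Is there any way to avoid this character in the pivot + aggregation step, or to recursively assign an alias that depends on the name of the new pivoted column?
Thanks in advance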
You can do the renaming inside the aggregation for the pivot by using alias:
import pyspark.sql.functions as f
data_wide = df.groupBy('user_id')\
.pivot('type')\
.agg(*[f.sum(x).alias(x) for x in df.columns if x not in {"user_id", "type"}])
data_wide.show()
#+-------+-----------------+------------------+----+------------------+----+------------------+
#|user_id| A_d1| A_d2|A_d3| B_d1|B_d2| B_d3|
#+-------+-----------------+------------------+----+------------------+----+------------------+
#| c1|6.199999999999999| 0.8| 3.8| 15.0| 0.2| 0.11|
#| c2| 7.3|18.799999999999997| 7.3|10.299999999999999|15.0|4.3999999999999995|
#+-------+-----------------+------------------+----+------------------+----+------------------+
However, this is really no different from doing the pivot and renaming afterwards. Here is the execution plan for this method:
#== Physical Plan ==
#HashAggregate(keys=[user_id#0], functions=[pivotfirst(type#1, sum(`d1`) AS `d1`#169, A, B, 0, 0), pivotfirst(type#1, sum(`d2`) AS `d2`#170, A, B, 0, 0), pivotfirst(type#1, sum(`d3`) AS `d3`#171, A, B, 0, 0)])
#+- Exchange hashpartitioning(user_id#0, 200)
# +- HashAggregate(keys=[user_id#0], functions=[partial_pivotfirst(type#1, sum(`d1`) AS `d1`#169, A, B, 0, 0), partial_pivotfirst(type#1, sum(`d2`) AS `d2`#170, A, B, 0, 0), partial_pivotfirst(type#1, sum(`d3`) AS `d3`#171, A, B, 0, 0)])
# +- *HashAggregate(keys=[user_id#0, type#1], functions=[sum(d1#2), sum(d2#3), sum(d3#4)])
# +- Exchange hashpartitioning(user_id#0, type#1, 200)
# +- *HashAggregate(keys=[user_id#0, type#1], functions=[partial_sum(d1#2), partial_sum(d2#3), partial_sum(d3#4)])
# +- Scan ExistingRDD[user_id#0,type#1,d1#2,d2#3,d3#4]
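Compare that with the approach of pivoting first and renaming afterwards:
import re
def clean_names(df):
    # Match names such as A_sum(`d1`): pivot value, aggregate name, source column
    # (the backticks are optional so that names without them also match)
    p = re.compile(r"^(\w+?)_([a-z]+)\(`?(\w+)`?\)(?:\(\))?")
    # Keep the pivot value and the source column, e.g. A_sum(`d1`) -> A_d1
    return df.toDF(*[p.sub(r"\1_\3", c) for c in df.columns])
pivoted = df.groupBy('user_id').pivot('type').sum()
clean_names(pivoted).explain()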
#== Physical Plan ==
#HashAggregate(keys=[user_id#0], functions=[pivotfirst(type#1, sum(`d1`)#363, A, B, 0, 0), pivotfirst(type#1, sum(`d2`)#364, A, B, 0, 0), pivotfirst(type#1, sum(`d3`)#365, A, B, 0, 0)])
#+- Exchange hashpartitioning(user_id#0, 200)
# +- HashAggregate(keys=[user_id#0], functions=[partial_pivotfirst(type#1, sum(`d1`)#363, A, B, 0, 0), partial_pivotfirst(type#1, sum(`d2`)#364, A, B, 0, 0), partial_pivotfirst(type#1, sum(`d3`)#365, A, B, 0, 0)])
# +- *HashAggregate(keys=[user_id#0, type#1], functions=[sum(d1#2), sum(d2#3), sum(d3#4)])
# +- Exchange hashpartitioning(user_id#0, type#1, 200)
# +- *HashAggregate(keys=[user_id#0, type#1], functions=[partial_sum(d1#2), partial_sum(d2#3), partial_sum(d3#4)])
# +- Scan ExistingRDD[user_id#0,type#1,d1#2,d2#3,d3#4]
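You'll see that the two plans are practically identical. You may gain a tiny speedup by avoiding the regex, but it will be negligible compared with the cost of the pivot itself.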
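As a side note on the alias approach: the names it produces can be worked out before the pivot runs, which is handy when building the inputCols list for a VectorAssembler. The sketch below is plain Python, and the helper pivot_output_names is a made-up name, not part of PySpark. It assumes Spark's usual naming rule, where several aggregate expressions give "<pivot value>_<alias>" and a single aggregate expression gives just the pivot value; check that against your Spark version before relying on it.

```python
from itertools import product


def pivot_output_names(pivot_values, aliases):
    """Predict the value-column names of groupBy(...).pivot(...).agg(...).

    Assumption (not taken from the thread): with one aggregate expression
    Spark names each column after the pivot value alone; with several it
    uses "<pivot value>_<alias>", ordered by pivot value first.
    """
    if len(aliases) == 1:
        return [str(value) for value in pivot_values]
    return [f"{value}_{alias}" for value, alias in product(pivot_values, aliases)]


# Same pivot values and aliases as in the answer above
names = pivot_output_names(["A", "B"], ["d1", "d2", "d3"])
print(names)  # ['A_d1', 'A_d2', 'A_d3', 'B_d1', 'B_d2', 'B_d3']
```

The grouping column (user_id here) is not included; prepend it yourself if you need the full column list.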
I wrote a simple and fast function to rename PySpark pivot tables. Enjoy! :)
# Rename the unwieldy default column names that Spark gives a pivot table
def rename_pivot_cols(rename_df, remove_agg):
    """Rename a Spark pivot table's default column names.
    Option 1: remove_agg = True: `2_sum(sum_amt)` --> `sum_amt_2`.
    Option 2: remove_agg = False: `2_sum(sum_amt)` --> `sum_sum_amt_2`
    """
    new_names = []
    for column in rename_df.columns:
        start_index = column.find('(')
        end_index = column.find(')')
        # Only touch names shaped like <pivot value>_<agg>(<column>)
        if 0 < start_index < end_index and '_' in column[:start_index]:
            # "2_sum" -> pivot value "2" and aggregate name "sum"
            pivot_value, _, agg = column[:start_index].rpartition('_')
            # Drop any backticks Spark put around the source column
            inner = column[start_index + 1:end_index].strip('`')
            prefix = '' if remove_agg else agg + '_'
            column = prefix + inner + '_' + pivot_value
        new_names.append(column)
    # One toDF call instead of one withColumnRenamed call per column
    return rename_df.toDF(*new_names)
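One thing to watch when the aggregate name is dropped from the column names: if the same column was aggregated in more than one way (say sum and avg), two pivot columns can end up with the same new name. The sketch below is plain Python that works on a list of column names, so it can be tried without a Spark session. The helpers strip_aggregate and checked_renames are hypothetical names introduced for illustration only.

```python
from collections import Counter


def strip_aggregate(column):
    """Turn 'A_sum(`d1`)' into 'A_d1'; leave any other name untouched."""
    start, end = column.find("("), column.rfind(")")
    if start <= 0 or end <= start:
        return column
    pivot_value, sep, _agg = column[:start].rpartition("_")
    if not sep:
        return column
    return pivot_value + "_" + column[start + 1:end].strip("`")


def checked_renames(columns, rename=strip_aggregate):
    """Apply `rename` to every name and fail loudly on duplicates."""
    new_names = [rename(c) for c in columns]
    clashes = sorted(n for n, k in Counter(new_names).items() if k > 1)
    if clashes:
        raise ValueError(f"renaming would create duplicate columns: {clashes}")
    return new_names


print(checked_renames(["user_id", "A_sum(`d1`)", "B_sum(`d1`)"]))
# ['user_id', 'A_d1', 'B_d1']
```

With a real DataFrame the result could then be applied in a single step, for example df.toDF(*checked_renames(df.columns)).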