Spark - How to join current and previous records in a DataFrame and assign an original field to all such occurrences
I need to scan a Hive table and add the value from the first record in a sequence to all the linked records.

The logic is:

- Find the first record (where previous_id is blank).
- Find the next record (current_id = previous_id).
- Repeat until there are no more linked records.
- Add the columns from the original record to all the linked records.
- Output the result to a Hive table.
Example source data:

current_id previous_id start_date
---------- ----------- ----------
100                    01/01/2001
200        100         02/02/2002
300        200         03/03/2003
Example output data:

current_id start_date
---------- ----------
100        01/01/2001
200        01/01/2001
300        01/01/2001
I can achieve this by creating two DataFrames from the source table and performing multiple joins. However, this approach seems far from ideal, since the data has to be cached to avoid re-querying the source with each iteration.
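For illustration, a minimal sketch of the multi-join approach I mean (the source table name is a placeholder, and the loop details are only an assumption about how such an implementation might look):

from pyspark.sql.functions import col

src = spark.table("some_hive_table")  # placeholder for the real Hive table

heads = src.where(col("previous_id").isNull()).cache()     # first records
links = src.where(col("previous_id").isNotNull()).cache()  # linked records

resolved = heads.select("current_id", "start_date")
frontier = resolved
while True:
    # Attach each head's start_date to the next level of linked records.
    frontier = (links.alias("n")
                .join(frontier.alias("c"),
                      col("n.previous_id") == col("c.current_id"))
                .select(col("n.current_id"), col("c.start_date"))
                .cache())
    if frontier.count() == 0:
        break
    resolved = resolved.union(frontier)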
Any suggestions on how to approach this?
I'm not sure this is an approach well suited to Spark.

There is no grouping ID/key in the data.

I'm not sure how Catalyst will optimise this - I'll look at that later. Will memory errors occur if the data gets too big?

Having made the data more complex, this does work. Here it is:
# No grouping key evident - this is more a linked list with ascending current_ids.
# Added more complexity to the example.
# Questions open on performance at scale. Interested to see how well Catalyst handles this.
# Really needs some grouping id/key in the data.
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, LongType, StringType

# Start from a DataFrame with some more complex data.
columns = ['current_id', 'previous_id', 'start_date']
vals = [
    (100, None, '2001/01/01'),
    (200, 100, '2002/02/02'),
    (300, 200, '2003/03/03'),
    (400, None, '2005/01/01'),
    (500, 400, '2006/02/02'),
    (600, 300, '2007/02/02'),
    (700, 600, '2008/02/02'),
    (800, None, '2009/02/02'),
    (900, 800, '2010/02/02')
]
df = spark.createDataFrame(vals, columns)
df.createOrReplaceTempView("trans")

# Starting data: the null / None entries, i.e. the heads of the chains.
df2 = spark.sql("""
    select *
    from trans
    where previous_id is null
""")
df2.cache()  # cache() must actually be called, not just referenced
df2.createOrReplaceTempView("trans_0")

# Traverse the linked records until the data is exhausted, writing each level
# to a dynamically named TempView.
# May need to checkpoint? Depends on the depth of the chain of linked items.
# Spark is not well suited to this type of processing.
dfX_cnt = 1
cnt = 1
while dfX_cnt != 0:
    tabname_prev = 'trans_' + str(cnt - 1)
    tabname = 'trans_' + str(cnt)
    query = ("select t2.current_id, t2.previous_id, t1.start_date "
             "from {} t1, trans t2 "
             "where t1.current_id = t2.previous_id").format(tabname_prev)
    dfX = spark.sql(query)
    dfX.cache()
    dfX_cnt = dfX.count()
    if dfX_cnt != 0:
        dfX.createOrReplaceTempView(tabname)
        cnt = cnt + 1

# Reduce the TempViews to a single DataFrame. One could also reduce a list of
# DataFrames with functools.reduce, but I could not find my notes on that.
# Will memory errors occur?
fields = [StructField('current_id', LongType(), False),
          StructField('previous_id', LongType(), True),
          StructField('start_date', StringType(), False)]
schema = StructType(fields)
dfZ = spark.createDataFrame(sc.emptyRDD(), schema)
for i in range(0, cnt):
    tabname = 'trans_' + str(i)
    df = spark.sql("select * from {}".format(tabname))
    dfZ = dfZ.union(df)

# Show the final results.
dfZ.select('current_id', 'start_date').sort(col('current_id')).show()
returns:
+----------+----------+
|current_id|start_date|
+----------+----------+
| 100|2001/01/01|
| 200|2001/01/01|
| 300|2001/01/01|
| 400|2005/01/01|
| 500|2005/01/01|
| 600|2001/01/01|
| 700|2001/01/01|
| 800|2009/02/02|
| 900|2009/02/02|
+----------+----------+
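On the "may need to checkpoint?" comment in the code above: every iteration's join extends the lineage of the cached DataFrames, so for deep chains it may be worth truncating the plan periodically. A minimal sketch of what could be added inside the while loop (the interval of 10 is an arbitrary assumption):

# Periodically materialise dfX and drop its lineage so the physical plan
# does not keep growing with the depth of the chain.
sc.setCheckpointDir("/tmp/chk")  # set once, before checkpoint() is first used
if cnt % 10 == 0:                # arbitrary interval; tune to the chain depth
    dfX = dfX.checkpoint()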
I think you can do this with GraphFrames connected components. It will help you avoid writing the checkpointing and looping logic yourself. Essentially, you build a graph from the current_id and previous_id pairs and use GraphFrames to compute the component for each vertex. The resulting DataFrame can then be joined to the original DataFrame to obtain the start_date.
from graphframes import *

# connectedComponents() requires a checkpoint directory to be set.
sc.setCheckpointDir("/tmp/chk")

input = spark.createDataFrame([
    (100, None, "2001-01-01"),
    (200, 100, "2002-02-02"),
    (300, 200, "2003-03-03"),
    (400, None, "2004-04-04"),
    (500, 400, "2005-05-05"),
    (600, 500, "2006-06-06"),
    (700, 300, "2007-07-07")
], ["current_id", "previous_id", "start_date"])

input.show()

vertices = input.select(input.current_id.alias("id"))
# Head records have no previous_id, so they contribute no edge.
edges = input.where(input.previous_id.isNotNull()) \
             .select(input.current_id.alias("src"), input.previous_id.alias("dst"))

graph = GraphFrame(vertices, edges)
result = graph.connectedComponents()

# The component label is the smallest id in each chain, i.e. the head record
# (the one whose previous_id is null), so joining component to current_id
# picks up each head's start_date for every member of its chain.
result.join(input, result.component == input.current_id) \
    .select(result.id.alias("current_id"), input.start_date) \
    .orderBy("current_id") \
    .show()
The result is as follows:
+----------+----------+
|current_id|start_date|
+----------+----------+
| 100|2001-01-01|
| 200|2001-01-01|
| 300|2001-01-01|
| 400|2004-04-04|
| 500|2004-04-04|
| 600|2004-04-04|
| 700|2001-01-01|
+----------+----------+
Thanks for the suggestions posted here. After trying various approaches, I've gone with the following solution, which works for a number of iterations (e.g. 20 loops) and does not cause any memory issues.

The "Physical Plan" is still huge, but caching means most of the steps are skipped, keeping performance acceptable.
input = spark.createDataFrame([
    (100, None, '2001/01/01'),
    (200, 100, '2002/02/02'),
    (300, 200, '2003/03/03'),
    (400, None, '2005/01/01'),
    (500, 400, '2006/02/02'),
    (600, 300, '2007/02/02'),
    (700, 600, '2008/02/02'),
    (800, None, '2009/02/02'),
    (900, 800, '2010/02/02')
], ["current_id", "previous_id", "start_date"])
input.createOrReplaceTempView("input")

# Split the heads of the chains (no previous_id) from the linked records.
cur = spark.sql("select * from input where previous_id is null")
nxt = spark.sql("select * from input where previous_id is not null")
cur.cache()
nxt.cache()
cur.createOrReplaceTempView("cur0")
nxt.createOrReplaceTempView("nxt")

# Walk one level of the chains per iteration, carrying the head's start_date
# forward, until a join produces no more rows.
i = 1
while True:
    # Variable substitution resolves ${table_name} to the previous level's view.
    spark.sql("set table_name=cur" + str(i - 1))
    cur = spark.sql(
        """
        SELECT nxt.current_id  as current_id,
               nxt.previous_id as previous_id,
               cur.start_date  as start_date
        FROM   ${table_name} cur,
               nxt nxt
        WHERE  cur.current_id = nxt.previous_id
        """).cache()
    cur.createOrReplaceTempView("cur" + str(i))
    i = i + 1
    if cur.count() == 0:
        break

# Union all the levels back into a single DataFrame.
for x in range(0, i):
    spark.sql("set table_name=cur" + str(x))
    cur = spark.sql("select * from ${table_name}")
    if x == 0:
        out = cur
    else:
        out = out.union(cur)
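To complete the original requirement of writing the result back to Hive, something along these lines would follow (the output table name is a placeholder):

# Show the resolved chains and persist them; 'linked_records' is a placeholder.
out.select("current_id", "start_date").orderBy("current_id").show()
out.select("current_id", "start_date") \
   .write.mode("overwrite") \
   .saveAsTable("linked_records")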