Spark 数据框中的序列
Sequences in Spark dataframe
我在 Spark 中有数据框。看起来像这样:
+-------+----------+-------+
| value| group| ts|
+-------+----------+-------+
| A| X| 1|
| B| X| 2|
| B| X| 3|
| D| X| 4|
| E| X| 5|
| A| Y| 1|
| C| Y| 2|
+-------+----------+-------+
Endgoal:我想找出有多少序列 A-B-E
(序列只是后续行的列表)。添加的约束是序列的后续部分最多可以分开 n
行。让我们考虑这个例子 n
是 2.
考虑组 X
。
在这种情况下,B
和 E
之间恰好有 1 个 D
(忽略多个连续的 B
)。这意味着 B
和 E
相隔 1 行,因此有一个序列 A-B-E
我考虑过使用 collect_list()
,创建一个字符串(如 DNA)并使用正则表达式进行子字符串搜索。但我想知道是否有更优雅的分布式方式,也许使用 window 函数?
编辑:
请注意,提供的数据框只是一个示例。真实的数据帧(以及组)可以是任意长的。
编辑以回答@Tim 的评论 + 修复 "AABE"
类型的模式
是的,使用 window 函数有帮助,但我创建了一个 id
来进行排序:
val df = List(
(1,"A","X",1),
(2,"B","X",2),
(3,"B","X",3),
(4,"D","X",4),
(5,"E","X",5),
(6,"A","Y",1),
(7,"C","Y",2)
).toDF("id","value","group","ts")
import org.apache.spark.sql.expressions.Window
val w = Window.partitionBy('group).orderBy('id)
然后 lag 将收集所需的内容,但需要一个函数来生成 Column
表达式(注意拆分以消除 "AABE" 的重复计数)。警告:这拒绝 "ABAEXX"):
类型的模式
def createSeq(m:Int) = split(
concat(
(1 to 2*m)
.map(i => coalesce(lag('value,-i).over(w),lit("")))
:_*),"A")(0)
val m=2
val tmp = df
.withColumn("seq",createSeq(m))
+---+-----+-----+---+----+
| id|value|group| ts| seq|
+---+-----+-----+---+----+
| 6| A| Y| 1| C|
| 7| C| Y| 2| |
| 1| A| X| 1|BBDE|
| 2| B| X| 2| BDE|
| 3| B| X| 3| DE|
| 4| D| X| 4| E|
| 5| E| X| 5| |
+---+-----+-----+---+----+
由于 Column
API 中可用的集合函数集很差,使用 UDF
完全避免正则表达式要容易得多
def patternInSeq(m: Int) = udf((str: String) => {
var notFound = str
.split("B")
.filter(_.contains("E"))
.filter(_.indexOf("E") <= m)
.isEmpty
!notFound
})
val res = tmp
.filter(('value === "A") && (locate("B",'seq) > 0))
.filter(locate("B",'seq) <= m && (locate("E",'seq) > 1))
.filter(patternInSeq(m)('seq))
.groupBy('group)
.count
res.show
+-----+-----+
|group|count|
+-----+-----+
| X| 1|
+-----+-----+
泛化(超出范围)
如果你想泛化更长的字母序列,问题必须被泛化。这可能是微不足道的,但在这种情况下,类型 ("ABAE") 的模式应该被拒绝(见评论)。因此,最简单的概括方法是在以下实现中采用成对规则(我添加了一个组 "Z" 来说明此算法的行为)
val df = List(
(1,"A","X",1),
(2,"B","X",2),
(3,"B","X",3),
(4,"D","X",4),
(5,"E","X",5),
(6,"A","Y",1),
(7,"C","Y",2),
( 8,"A","Z",1),
( 9,"B","Z",2),
(10,"D","Z",3),
(11,"B","Z",4),
(12,"E","Z",5)
).toDF("id","value","group","ts")
首先我们定义一对的逻辑
import org.apache.spark.sql.DataFrame
def createSeq(m:Int) = array((0 to 2*m).map(i => coalesce(lag('value,-i).over(w),lit(""))):_*)
def filterPairUdf(m: Int, t: (String,String)) = udf((ar: Array[String]) => {
val (a,b) = t
val foundAt = ar
.dropWhile(_ != a)
.takeWhile(_ != a)
.indexOf(b)
foundAt != -1 && foundAt <= m
})
然后我们定义一个函数,将此逻辑迭代应用于数据帧
def filterSeq(seq: List[String], m: Int)(df: DataFrame): DataFrame = {
var a = seq(0)
seq.tail.foldLeft(df){(df: DataFrame, b: String) => {
val res = df.filter(filterPairUdf(m,(a,b))('seq))
a = b
res
}}
}
得到简化和优化,因为我们首先过滤从第一个字符开始的序列
val m = 2
val tmp = df
.filter('value === "A") // reduce problem
.withColumn("seq",createSeq(m))
scala> tmp.show()
+---+-----+-----+---+---------------+
| id|value|group| ts| seq|
+---+-----+-----+---+---------------+
| 6| A| Y| 1| [A, C, , , ]|
| 8| A| Z| 1|[A, B, D, B, E]|
| 1| A| X| 1|[A, B, B, D, E]|
+---+-----+-----+---+---------------+
val res = tmp.transform(filterSeq(List("A","B","E"),m))
scala> res.show()
+---+-----+-----+---+---------------+
| id|value|group| ts| seq|
+---+-----+-----+---+---------------+
| 1| A| X| 1|[A, B, B, D, E]|
+---+-----+-----+---+---------------+
(transform
是DataFrame => DataFrame
变换的简单糖衣)
res
.groupBy('group)
.count
.show
+-----+-----+
|group|count|
+-----+-----+
| X| 1|
+-----+-----+
正如我所说,在扫描序列时有不同的方法来概括 "resetting rules",但是这个例子希望有助于实现更复杂的方法。
我在 Spark 中有数据框。看起来像这样:
+-------+----------+-------+
| value| group| ts|
+-------+----------+-------+
| A| X| 1|
| B| X| 2|
| B| X| 3|
| D| X| 4|
| E| X| 5|
| A| Y| 1|
| C| Y| 2|
+-------+----------+-------+
Endgoal:我想找出有多少序列 A-B-E
(序列只是后续行的列表)。添加的约束是序列的后续部分最多可以分开 n
行。让我们考虑这个例子 n
是 2.
考虑组 X
。
在这种情况下,B
和 E
之间恰好有 1 个 D
(忽略多个连续的 B
)。这意味着 B
和 E
相隔 1 行,因此有一个序列 A-B-E
我考虑过使用 collect_list()
,创建一个字符串(如 DNA)并使用正则表达式进行子字符串搜索。但我想知道是否有更优雅的分布式方式,也许使用 window 函数?
编辑:
请注意,提供的数据框只是一个示例。真实的数据帧(以及组)可以是任意长的。
编辑以回答@Tim 的评论 + 修复 "AABE"
类型的模式是的,使用 window 函数有帮助,但我创建了一个 id
来进行排序:
val df = List(
(1,"A","X",1),
(2,"B","X",2),
(3,"B","X",3),
(4,"D","X",4),
(5,"E","X",5),
(6,"A","Y",1),
(7,"C","Y",2)
).toDF("id","value","group","ts")
import org.apache.spark.sql.expressions.Window
val w = Window.partitionBy('group).orderBy('id)
然后 lag 将收集所需的内容,但需要一个函数来生成 Column
表达式(注意拆分以消除 "AABE" 的重复计数)。警告:这拒绝 "ABAEXX"):
def createSeq(m:Int) = split(
concat(
(1 to 2*m)
.map(i => coalesce(lag('value,-i).over(w),lit("")))
:_*),"A")(0)
val m=2
val tmp = df
.withColumn("seq",createSeq(m))
+---+-----+-----+---+----+
| id|value|group| ts| seq|
+---+-----+-----+---+----+
| 6| A| Y| 1| C|
| 7| C| Y| 2| |
| 1| A| X| 1|BBDE|
| 2| B| X| 2| BDE|
| 3| B| X| 3| DE|
| 4| D| X| 4| E|
| 5| E| X| 5| |
+---+-----+-----+---+----+
由于 Column
API 中可用的集合函数集很差,使用 UDF
def patternInSeq(m: Int) = udf((str: String) => {
var notFound = str
.split("B")
.filter(_.contains("E"))
.filter(_.indexOf("E") <= m)
.isEmpty
!notFound
})
val res = tmp
.filter(('value === "A") && (locate("B",'seq) > 0))
.filter(locate("B",'seq) <= m && (locate("E",'seq) > 1))
.filter(patternInSeq(m)('seq))
.groupBy('group)
.count
res.show
+-----+-----+
|group|count|
+-----+-----+
| X| 1|
+-----+-----+
泛化(超出范围)
如果你想泛化更长的字母序列,问题必须被泛化。这可能是微不足道的,但在这种情况下,类型 ("ABAE") 的模式应该被拒绝(见评论)。因此,最简单的概括方法是在以下实现中采用成对规则(我添加了一个组 "Z" 来说明此算法的行为)
val df = List(
(1,"A","X",1),
(2,"B","X",2),
(3,"B","X",3),
(4,"D","X",4),
(5,"E","X",5),
(6,"A","Y",1),
(7,"C","Y",2),
( 8,"A","Z",1),
( 9,"B","Z",2),
(10,"D","Z",3),
(11,"B","Z",4),
(12,"E","Z",5)
).toDF("id","value","group","ts")
首先我们定义一对的逻辑
import org.apache.spark.sql.DataFrame
def createSeq(m:Int) = array((0 to 2*m).map(i => coalesce(lag('value,-i).over(w),lit(""))):_*)
def filterPairUdf(m: Int, t: (String,String)) = udf((ar: Array[String]) => {
val (a,b) = t
val foundAt = ar
.dropWhile(_ != a)
.takeWhile(_ != a)
.indexOf(b)
foundAt != -1 && foundAt <= m
})
然后我们定义一个函数,将此逻辑迭代应用于数据帧
def filterSeq(seq: List[String], m: Int)(df: DataFrame): DataFrame = {
var a = seq(0)
seq.tail.foldLeft(df){(df: DataFrame, b: String) => {
val res = df.filter(filterPairUdf(m,(a,b))('seq))
a = b
res
}}
}
得到简化和优化,因为我们首先过滤从第一个字符开始的序列
val m = 2
val tmp = df
.filter('value === "A") // reduce problem
.withColumn("seq",createSeq(m))
scala> tmp.show()
+---+-----+-----+---+---------------+
| id|value|group| ts| seq|
+---+-----+-----+---+---------------+
| 6| A| Y| 1| [A, C, , , ]|
| 8| A| Z| 1|[A, B, D, B, E]|
| 1| A| X| 1|[A, B, B, D, E]|
+---+-----+-----+---+---------------+
val res = tmp.transform(filterSeq(List("A","B","E"),m))
scala> res.show()
+---+-----+-----+---+---------------+
| id|value|group| ts| seq|
+---+-----+-----+---+---------------+
| 1| A| X| 1|[A, B, B, D, E]|
+---+-----+-----+---+---------------+
(transform
是DataFrame => DataFrame
变换的简单糖衣)
res
.groupBy('group)
.count
.show
+-----+-----+
|group|count|
+-----+-----+
| X| 1|
+-----+-----+
正如我所说,在扫描序列时有不同的方法来概括 "resetting rules",但是这个例子希望有助于实现更复杂的方法。