Spark 中的 Where 子句与结构数组中的元素之间?

Where clause in Spark with between for element in array of struct?

我有这两种情况 class :

case class Doc(posts: Seq[Post], test: String)
case class Post(postId: Int, createdTime: Long)

我创建了一个样本 df :

val df = spark.sparkContext.parallelize(Seq(
Doc(Seq(
  Post(1, 1),
  Post(2, 3),
  Post(3, 8),
  Post(4, 15)
), null),
Doc(Seq(
  Post(5, 6),
  Post(6, 9),
  Post(7, 12),
  Post(8, 20)
), "hello") )).toDF()

所以我想要的是,return 在线文档,其中 createTime 介于 x 和 y 之间。 例如,对于 x = 2 et y = 9,我希望此结果具有与原始 df 相同的模式:

+--------------+
|         posts|
+--------------+
|[[2,3], [3,8]]|
|[[5,6], [6,9]]|
+--------------+

所以我尝试了很多 where 的组合,但我没有用。 我尝试使用 map(_.filter(...)),但我不想做的问题 toDF().as[Doc]

有什么帮助吗? 谢谢

有几种方法可以做到这一点:

  • 通过使用 UDF
  • 通过使用分解和收集
  • 通过使用数据块工具

UDF

UDF 是万能的。您基本上创建了一个自定义函数来完成这项工作。与转换为数据集不同,它不会构建整个 Doc class,而是只会处理相关数据:

def f(posts: Seq[Row]): Seq[Post] = {
  posts.map(r => Post(r.getAs[Int](0), r.getAs[Long](1))).filter(p => p.createdTime > 3 && p.createdTime < 9))
}
val u = udf(f _)
val filtered = df.withColumn("posts", u($"posts"))

使用爆炸和 collect_list

df.withColumn("posts", explode($"posts")).filter($"posts.createdTime" > 3 && $"posts.createdTime" < 9).groupBy("test").agg(collect_list("posts").as("posts"))

这可能比前一个效率低,但它是一个衬里(并且在未来的某一时刻它可能会得到优化)。

使用数据块工具

如果您正在使用 Databricks Cloud,则可以使用高阶函数。 有关详细信息,请参阅 here。由于这不是一般的通用 spark 选项,因此我不会讨论它。
希望将来他们会将其集成到标准 spark 中(我在该主题上发现 this jira,但目前不支持)。