Spark RDD Or SQL operations to compute conditional counts

As background, I'm trying to implement Kaplan-Meier in Spark. In particular, I assume I have a data frame/set with a Double column named data and an Int column named censorFlag (0 if censored, 1 if not; I prefer this over a Boolean type).

An example:

val df = Seq((1.0, 1), (2.3, 0), (4.5, 1), (0.8, 1), (0.7, 0), (4.0, 1), (0.8, 1)).toDF("data", "censorFlag").as[(Double, Int)] 

Now I need to compute a column wins that counts the instances of censorFlag == 1 for each value of data. I achieved this with the following code:

val distDF = df.withColumn("wins", sum(col("censorFlag")).over(Window.partitionBy("data").orderBy("data")))

The problem comes when I need to compute a quantity called atRisk: for each value of data, the count of data values that are greater than or equal to it (a cumulative filtered count, if you will).

The following code works:

// We perform the counts per value of "bins". This is an array of doubles
val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins").as[Double].collect 
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// this works:
atRiskCounts.show

However, the use case calls for deriving bins from the column data itself, and I would rather keep it as a single-column Dataset (or, at worst, an RDD), but certainly not a local array. This, however, does not work:

// Here, 'bins' rightfully come from the data itself.
val bins = df.select(col("data").as("dataBins")).distinct().as[Double]
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// This doesn't work -- NullPointerException
atRiskCounts.show

Neither does this:

// Manually creating the bins and then parallelizing them.
val bins = Seq(0.7, 0.8, 1.0, 3.0).toDS
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toDF("data", "atRisk")
// Also fails with a NullPointerException
atRiskCounts.show

Another approach that does work, but is also unsatisfactory from a parallelization standpoint, is to use a Window:

// Do the counts in one fell swoop using a giant window per value.
val atRiskCounts = df.withColumn("atRisk", count("censorFlag").over(Window.orderBy("data").rowsBetween(0, Window.unboundedFollowing))).groupBy("data").agg(first("atRisk").as("atRisk"))
// Works, BUT, we get a "WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation." 
atRiskCounts.show

The last solution is not useful, since it ends up shuffling my data to a single partition (and in that case I might as well go with option 1, which works).

The successful approach is fine, except that bins is not parallel, and I would really like to keep it that way if possible. I have looked at groupBy aggregations and pivot-type aggregations, but none of them seems to make sense here.

My question is: is there a way to compute the atRisk column in a distributed fashion? Also, why do I get a NullPointerException in the failed solutions?

Edit based on comments:

I did not originally post the NullPointerException because it did not seem to contain anything useful. I will note that this is Spark installed via Homebrew on my MacBook Pro (Spark version 2.2.1, standalone local-host mode).

                18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.package on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/package.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.scala on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/scala.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.org on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/org.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.java on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/java.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR Executor: Exception in task 0.0 in stage 55.0 (TID 432)
            java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun.apply(<console>:33)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$$anon.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$$anonfun$apply.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$$anonfun$apply.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
                at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
                at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
                at java.lang.Thread.run(Thread.java:748)
            18/03/12 11:41:00 WARN TaskSetManager: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun.apply(<console>:33)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$$anon.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$$anonfun$apply.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$$anonfun$apply.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)

            18/03/12 11:41:00 ERROR TaskSetManager: Task 0 in stage 55.0 failed 1 times; aborting job
            org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 55.0 failed 1 times, most recent failure: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $anonfun.apply(<console>:33)
                at $anonfun.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$$anon.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$$anonfun$apply.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$$anonfun$apply.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)

            Driver stacktrace:
              at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1517)
              at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage.apply(DAGScheduler.scala:1505)
              at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage.apply(DAGScheduler.scala:1504)
              at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
              ... 50 elided
            Caused by: java.lang.NullPointerException
              at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
              at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
              at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
              at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
              at $anonfun.apply(<console>:33)
              at $anonfun.apply(<console>:33)
              at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
              at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
              at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$$anon.hasNext(WholeStageCodegenExec.scala:395)
              at org.apache.spark.sql.execution.SparkPlan$$anonfun.apply(SparkPlan.scala:234)
              at org.apache.spark.sql.execution.SparkPlan$$anonfun.apply(SparkPlan.scala:228)
              at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$$anonfun$apply.apply(RDD.scala:827)
              at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$$anonfun$apply.apply(RDD.scala:827)
              at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
              at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
              at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
              at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
              at org.apache.spark.scheduler.Task.run(Task.scala:108)
              at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
              at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
              at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
              at java.lang.Thread.run(Thread.java:748)

My best guess is that the df.filter(col("data").geq(x)).count part is what barfs, since not every node may have x, hence the null pointer?

I haven't tested this, so the syntax may be silly, but I would do a series of joins:

I believe your first statement is equivalent to the following, which counts the wins for each value of data:

val distDF = df.groupBy($"data").agg(sum($"censorFlag").as("wins"))

Then, as you noted, we can build the dataframe of bins:

val distinctData = df.select($"data".as("dataBins")).distinct()

And then join with the >= condition:

val atRiskCounts = distDF.join(distinctData, distDF("data") >= distinctData("dataBins"))
  .groupBy($"data", $"wins")
  .count()
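
The grouped count is the atRisk quantity, so a possible finishing step (a quick sketch on my part, not something I have run) is to rename the count column and order by data, leaving you with data, wins, and atRisk:

// Sketch: rename the grouped count to atRisk and sort for readability.
val finalDF = atRiskCounts
  .withColumnRenamed("count", "atRisk")
  .orderBy("data")
finalDF.show()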

Collecting matters when you need to check the values of a column against all the other values of that same column. And when every value has to be checked against every other, it is safe to say that all the data of that column needs to be accumulated in one executor or in the driver. With a requirement like yours, you cannot avoid that step.

Now, the main part is how to define the remaining steps so that they benefit from Spark's parallelization. I suggest you broadcast the collected set (it is the distinct data of a single column, so it cannot be very large) and use a udf function to check the gte condition, as follows.

First you can optimize your collect step as:

import org.apache.spark.sql.functions._
val collectedData = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]

Then you broadcast the collected set:

val broadcastedArray = sc.broadcast(collectedData)

The next step is to define a udf function that checks the gte condition and returns the count:

def checkingUdf = udf((data: Double)=> broadcastedArray.value.count(x => x >= data))

and use it as:

distDF.withColumn("atRisk", checkingUdf(col("data"))).show(false)

So in the end you should get:

+----+----------+----+------+
|data|censorFlag|wins|atRisk|
+----+----------+----+------+
|4.5 |1         |1   |1     |
|0.7 |0         |0   |6     |
|2.3 |0         |0   |3     |
|1.0 |1         |1   |4     |
|0.8 |1         |2   |5     |
|0.8 |1         |2   |5     |
|4.0 |1         |1   |2     |
+----+----------+----+------+

I hope this is the dataframe you need.
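
One small housekeeping step I would add on my own (an assumption on my part, not part of the answer above): once the broadcast variable is no longer needed, release it explicitly.

// Release the broadcast once the computation is done.
broadcastedArray.unpersist()
// or, if it will never be reused in this session:
// broadcastedArray.destroy()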

I tried the examples above (though not the most rigorous of tests!), and in general the join approach works best.

The data:

import org.apache.spark.mllib.random.RandomRDDs._
// Log-normal "data" values paired with a 0/1 censorFlag (1 with probability ~0.4), rounded to 2 decimals.
val sampleDF = logNormalRDD(sc, 1, 3.0, 10000, 100).zip(uniformRDD(sc, 10000, 100).map(x => if(x <= 0.4) 1 else 0)).toDF("data", "censorFlag").withColumn("data", round(col("data"), 2))

The join example:

def runJoin(sc: SparkContext, df: DataFrame): Unit = {
  // The distinct, sorted data values serve as the bins.
  val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins")
  // wins: number of censorFlag == 1 rows per data value.
  val wins = df.groupBy(col("data")).agg(sum("censorFlag").as("wins"))
  // atRisk: for each bin, the number of rows with data >= that bin.
  val atRiskCounts = bins.join(df, bins("dataBins") <= df("data")).groupBy("dataBins").count().withColumnRenamed("count", "atRisk")
  val finalDF = wins.join(atRiskCounts, wins("data") === atRiskCounts("dataBins")).select("data", "wins", "atRisk").sort("data")
  finalDF.show
}

The broadcast example:

def runBroadcast(sc: SparkContext, df: DataFrame): Unit = {
  // Collect the distinct, sorted data values locally, then broadcast them.
  val bins = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]
  val binsBroadcast = sc.broadcast(bins)
  // One filter-and-count job per bin, driven from the local array.
  val df2 = binsBroadcast.value.map(x => (x, df.filter(col("data").geq(x)).select(count(col("data"))).as[Long].first)).toDF("data", "atRisk")
  val finalDF = df.groupBy(col("data")).agg(sum("censorFlag").as("wins")).join(df2, "data")
  finalDF.show
  binsBroadcast.destroy
}

And the test code:

import java.util.concurrent.TimeUnit

var start = System.nanoTime()
runJoin(sc, sampleDF)
val joinTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)

start = System.nanoTime()
runBroadcast(sc, sampleDF)
val broadTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)

I ran this code against random data of various sizes, supplying the bins array manually (some quite granular, 50% of the original distinct data; some very small, 10% of the original distinct data), and consistently the join method seems to be the fastest (though both arrive at the same solution, which is a plus!).

On average, I find that the smaller the bins array, the better the broadcast method does, but the join does not seem to be affected much. If I had more time/resources to test this, I would run lots of simulations to see what the average run time looks like, but for now I will accept @hoyland's solution.
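
For what it's worth, the kind of repeated-timing loop I have in mind would look roughly like the sketch below (a hypothetical helper, not something I actually ran):

// Hypothetical helper: run a block several times and report the mean wall-clock time in seconds.
def timeAvg(runs: Int)(block: => Unit): Double = {
  val seconds = (1 to runs).map { _ =>
    val start = System.nanoTime()
    block
    (System.nanoTime() - start) / 1e9
  }
  seconds.sum / runs
}

// Example: average over 5 runs of each approach.
val joinAvg  = timeAvg(5) { runJoin(sc, sampleDF) }
val broadAvg = timeAvg(5) { runBroadcast(sc, sampleDF) }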

I am still not sure why the original approach does not work, so comments on that are welcome.

Please let me know of any issues or improvements in my code! Thank you :)