在 spark scala 的数据框中为每个组采样不同数量的随机行

Sample a different number of random rows for every group in a dataframe in spark scala

目标是为每个组采样(不替换)数据框中不同数量的行。为特定组采样的行数在另一个数据框中。

示例:idDF 是要从中采样的数据帧。这些组由 ID 列表示。数据框 planDF 指定要为每个组采样的行数,其中 "datesToUse" 表示行数,"ID" 表示组。 "totalDates" 是该组的总行数,可能有用也可能没用。

最终结果应该是从第一组(ID 1)抽取3行,从第二组(ID 2)抽取2行,从第三组(ID 3)抽取1行。

val idDF = Seq(
  (1, "2017-10-03"),
  (1, "2017-10-22"),
  (1, "2017-11-01"),
  (1, "2017-10-02"),
  (1, "2017-10-09"),
  (1, "2017-12-24"),
  (1, "2017-10-20"),
  (2, "2017-11-17"),
  (2, "2017-11-12"),
  (2, "2017-12-02"),      
  (2, "2017-10-03"),
  (3, "2017-12-18"),
  (3, "2017-11-21"),
  (3, "2017-12-13"),
  (3, "2017-10-08"),
  (3, "2017-10-16"),
  (3, "2017-12-04")
 ).toDF("ID", "date")

val planDF = Seq(
  (1, 3, 7),
  (2, 2, 4),
  (3, 1, 6)
 ).toDF("ID", "datesToUse", "totalDates")

这是生成的数据框的示例:

+---+----------+
| ID|      date|
+---+----------+
|  1|2017-10-22|
|  1|2017-11-01|
|  1|2017-10-20|
|  2|2017-11-12|
|  2|2017-10-03|
|  3|2017-10-16|
+---+----------+

到目前为止,我尝试使用 DataFrame 的示例方法:https://spark.apache.org/docs/1.5.0/api/java/org/apache/spark/sql/DataFrame.html 这是一个适用于整个数据框的示例。

def sampleDF(DF: DataFrame, datesToUse: Int, totalDates: Int): DataFrame = {
  val fraction = datesToUse/totalDates.toFloat.toDouble
  DF.sample(false, fraction)
}

我不知道如何为每个组使用这样的东西。我尝试将 planDF table 加入 idDF table 并使用 window 分区。

我的另一个想法是以某种方式创建一个随机标记为 True / false 的新列,然后对该列进行过滤。

假设您的 planDF 足够小,可以 collected,您可以使用 Scala 的 foldLeft 来遍历 id 列表并累积每个 [=14] 的样本数据帧=]:

import org.apache.spark.sql.{Row, DataFrame}

def sampleByIdDF(DF: DataFrame, id: Int, datesToUse: Int, totalDates: Int): DataFrame = {
  val fraction = datesToUse.toDouble / totalDates
  DF.where($"id" === id ).sample(false, fraction)
}

val emptyDF = Seq.empty[(Int, String)].toDF("ID", "date")

val planList = planDF.rdd.collect.map{ case Row(x: Int, y: Int, z: Int) => (x, y, z) }
// planList: Array[(Int, Int, Int)] = Array((1,3,7), (2,2,4), (3,1,6))

planList.foldLeft( emptyDF ){
  case (accDF: DataFrame, (id: Int, num: Int, total: Int)) =>
    accDF union sampleByIdDF(idDF, id, num, total)
}
// res1: org.apache.spark.sql.DataFrame = [ID: int, date: string]

// res1.show
// +---+----------+
// | ID|      date|
// +---+----------+
// |  1|2017-10-03|
// |  1|2017-11-01|
// |  1|2017-10-02|
// |  1|2017-12-24|
// |  1|2017-10-20|
// |  2|2017-11-17|
// |  2|2017-11-12|
// |  2|2017-12-02|
// |  3|2017-11-21|
// |  3|2017-12-13|
// +---+----------+

请注意,方法 sample() 不一定生成方法参数中指定的确切样本数。这是相关的 SO Q&A.

如果您的 planDF 很大,您可能不得不考虑使用 RDD 的 aggregate,它具有以下签名(跳过隐式参数):

def aggregate[U](zeroValue: U)(seqOp: (U, T) ⇒ U, combOp: (U, U) ⇒ U): U

它的工作方式有点像 foldLeft,只是它在一个分区内有一个累加运算符,另外还有一个用于合并来自不同分区的结果。

另一个完全保留在 Dataframes 中的选项是使用 planDF 计算概率,加入 idDF,附加一列随机数,然后进行过滤。有用的是,sql.functions 有一个 rand 函数。

import org.apache.spark.sql.functions._

import spark.implicits._

val probabilities = planDF.withColumn("prob", $"datesToUse" / $"totalDates")

val dfWithProbs = idDF.join(probabilities, Seq("ID"))
  .withColumn("rand", rand())
  .where($"rand" < $"prob")

(您需要仔细检查这不是整数除法。)