Apache Spark 中的 DataFrame 相等性

DataFrame equality in Apache Spark

假设 df1df2 是 Apache Spark 中的两个 DataFrame,使用两种不同的机制计算,例如,Spark SQL 与 Scala/Java/Python API。

是否有一种惯用的方法来确定两个数据帧是否等价(相等,同构),其中等价由数据(每行的列名和列值)相同来确定,除了行的顺序& 列?

这个问题的动机是通常有很多方法可以计算一些大数据结果,每种方法都有自己的权衡。在探索这些权衡时,保持正确性很重要,因此需要在有意义的测试数据集上检查 equivalence/equality。

Apache Spark 测试套件中有一些标准方法,但是其中大部分涉及在本地收集数据,如果您想对大型 DataFrame 进行相等性测试,那么这可能不是合适的解决方案。

首先检查模式,然后你可以对 df3 进行交集并验证 df1、df2 和 df3 的计数是否都相等(但是这仅在没有重复的行,如果有不同的重复行时有效行此方法仍然可以 return true).

另一种选择是获取两个 DataFrame 的底层 RDD,映射到 (Row, 1),执行 reduceByKey 来计算每行的数量,然后将两个结果 RDD 组合在一起,然后执行常规操作如果任何迭代器不相等,则聚合和 return false。

我不知道惯用语,但我认为您可以获得一种可靠的方法来比较数据帧,如下所述。 (我使用 PySpark 进行说明,但该方法适用于多种语言。)

a = spark.range(5)
b = spark.range(5)

a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()

assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0

这种方法可以正确处理 DataFrame 可能有重复行、不同顺序的行、不同顺序的 and/or 列的情况。

例如:

a = spark.createDataFrame([('nick', 30), ('bob', 40)], ['name', 'age'])
b = spark.createDataFrame([(40, 'bob'), (30, 'nick')], ['age', 'name'])
c = spark.createDataFrame([('nick', 30), ('bob', 40), ('nick', 30)], ['name', 'age'])

a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()
c_prime = c.groupBy(sorted(c.columns)).count()

assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0
assert a_prime.subtract(c_prime).count() != 0

这种方法非常昂贵,但考虑到需要执行完整差异,大部分费用是不可避免的。这应该可以很好地扩展,因为它不需要在本地收集任何东西。如果放宽比较应考虑重复行的限制,则可以删除 groupBy() 并只执行 subtract(),这可能会显着加快速度。

Scala(PySpark 见下文)

spark-fast-tests 库有两种比较 DataFrame 的方法(我是库的创建者):

assertSmallDataFrameEquality方法在driver节点上收集DataFrame并进行比较

def assertSmallDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  if (!actualDF.collect().sameElements(expectedDF.collect())) {
    throw new DataFrameContentMismatch(contentMismatchMessage(actualDF, expectedDF))
  }
}

assertLargeDataFrameEquality方法比较分布在多台机器上的DataFrames(代码基本是从spark-testing-base复制过来的)

def assertLargeDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  try {
    actualDF.rdd.cache
    expectedDF.rdd.cache

    val actualCount = actualDF.rdd.count
    val expectedCount = expectedDF.rdd.count
    if (actualCount != expectedCount) {
      throw new DataFrameContentMismatch(countMismatchMessage(actualCount, expectedCount))
    }

    val expectedIndexValue = zipWithIndex(actualDF.rdd)
    val resultIndexValue = zipWithIndex(expectedDF.rdd)

    val unequalRDD = expectedIndexValue
      .join(resultIndexValue)
      .filter {
        case (idx, (r1, r2)) =>
          !(r1.equals(r2) || RowComparer.areRowsEqual(r1, r2, 0.0))
      }

    val maxUnequalRowsToShow = 10
    assertEmpty(unequalRDD.take(maxUnequalRowsToShow))

  } finally {
    actualDF.rdd.unpersist()
    expectedDF.rdd.unpersist()
  }
}

assertSmallDataFrameEquality 对于小型 DataFrame 比较更快,我发现它对我的测试套件来说已经足够了。

PySpark

这是一个简单的函数,如果数据帧相等,return为真:

def are_dfs_equal(df1, df2):
    if df1.schema != df2.schema:
        return False
    if df1.collect() != df2.collect():
        return False
    return True

或简化

def are_dfs_equal(df1, df2): 
    return (df1.schema == df2.schema) and (df1.collect() == df2.collect())

您通常会在测试套件中执行 DataFrame 相等性比较,并且在比较失败时需要描述性错误消息(True / False return 值不调试时很有帮助)。

使用 chispa 库访问 assert_df_equality 方法,该方法 return 用于测试套件工作流的描述性错误消息。

您可以结合使用一点重复数据删除和完全外部联接来执行此操作。这种方法的优点是它不需要您向驱动程序收集结果,并且它避免了 运行 多个作业。

import org.apache.spark.sql._
import org.apache.spark.sql.functions._

// Generate some random data.
def random(n: Int, s: Long) = {
  spark.range(n).select(
    (rand(s) * 10000).cast("int").as("a"),
    (rand(s + 5) * 1000).cast("int").as("b"))
}
val df1 = random(10000000, 34)
val df2 = random(10000000, 17)

// Move all the keys into a struct (to make handling nulls easy), deduplicate the given dataset
// and count the rows per key.
def dedup(df: Dataset[Row]): Dataset[Row] = {
  df.select(struct(df.columns.map(col): _*).as("key"))
    .groupBy($"key")
    .agg(count(lit(1)).as("row_count"))
}

// Deduplicate the inputs and join them using a full outer join. The result can contain
// the following things:
// 1. Both keys are not null (and thus equal), and the row counts are the same. The dataset
//    is the same for the given key.
// 2. Both keys are not null (and thus equal), and the row counts are not the same. The dataset
//    contains the same keys.
// 3. Only the right key is not null.
// 4. Only the left key is not null.
val joined = dedup(df1).as("l").join(dedup(df2).as("r"), $"l.key" === $"r.key", "full")

// Summarize the differences.
val summary = joined.select(
  count(when($"l.key".isNotNull && $"r.key".isNotNull && $"r.row_count" === $"l.row_count", 1)).as("left_right_same_rc"),
  count(when($"l.key".isNotNull && $"r.key".isNotNull && $"r.row_count" =!= $"l.row_count", 1)).as("left_right_different_rc"),
  count(when($"l.key".isNotNull && $"r.key".isNull, 1)).as("left_only"),
  count(when($"l.key".isNull && $"r.key".isNotNull, 1)).as("right_only"))
summary.show()

Java:

assert resultDs.union(answerDs).distinct().count() == resultDs.intersect(answerDs).count();

尝试执行以下操作:

df1.except(df2).isEmpty
try {
  return ds1.union(ds2)
          .groupBy(columns(ds1, ds1.columns()))
          .count()
          .filter("count % 2 > 0")
          .count()
      == 0;
} catch (Exception e) {
  return false;
}

Column[] columns(Dataset<Row> ds, String... columnNames) {
List<Column> l = new ArrayList<>();
for (String cn : columnNames) {
  l.add(ds.col(cn));
}
return l.stream().toArray(Column[]::new);}

columns 方法是补充方法,可以用 returns Seq

的任何方法代替

逻辑:

  1. 合并两个数据集,如果列不匹配,它将抛出异常,因此 return false。
  2. 如果列匹配,则对所有列进行 groupBy 并添加列计数。现在,所有行的计数都是 2 的倍数(即使是重复行)。
  3. 检查是否有任何行的计数不能被 2 整除,这些是额外的行。

一种可扩展且简单的方法是区分两个 DataFrame 并计算不匹配的行数:

df1.diff(df2).where($"diff" != "N").count

如果该数字不为零,则两个 DataFrame 不相等。

diff 转换由 spark-extension 提供。

它标识Inserted,Changed,Deleted 和uN-更改行。

有 4 个选项,具体取决于您是否有 重复 行。

假设我们有两个 DataFrames,z1 和 z1。选项 1/2 适用于没有 重复行的 行。您可以在 spark-shell.

中尝试这些
  • 选项 1:直接执行 except
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column

def isEqual(left: DataFrame, right: DataFrame): Boolean = {
   if(left.columns.length != right.columns.length) return false // column lengths don't match
   if(left.count != right.count) return false // record count don't match
   return left.except(right).isEmpty && right.except(left).isEmpty
}
  • 选项 2:按列生成行散列
def createHashColumn(df: DataFrame) : Column = {
   val colArr = df.columns
   md5(concat_ws("", (colArr.map(col(_))) : _*))
}

val z1SigDF = z1.select(col("index"), createHashColumn(z1).as("signature_z1"))
val z2SigDF = z2.select(col("index"), createHashColumn(z2).as("signature_z2"))
val joinDF = z1SigDF.join(z2SigDF, z1SigDF("index") === z2SigDF("index")).where($"signature_z1" =!= $"signature_z2").cache
// should be 0
joinDF.count
  • 选项 3:使用 GroupBy(对于具有重复行的 DataFrame)
val z1Grouped = z1.groupBy(z1.columns.map(c => z1(c)).toSeq : _*).count().withColumnRenamed("count", "recordRepeatCount")
val z2Grouped = z2.groupBy(z2.columns.map(c => z2(c)).toSeq : _*).count().withColumnRenamed("count", "recordRepeatCount")

val inZ1NotInZ2 = z1Grouped.except(z2Grouped).toDF()
val inZ2NotInZ1 = z2Grouped.except(z1Grouped).toDF()
// both should be size 0
inZ1NotInZ2.show
inZ2NotInZ1.show
  • 选项 4,使用 exceptAll,它也适用于具有重复行的数据
// Source Code: https://github.com/apache/spark/blob/50538600ec/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L2029
val inZ1NotInZ2 = z1.exceptAll(z2).toDF()
val inZ2NotInZ1 = z2.exceptAll(z1).toDF()
// same here, // both should be size 0
inZ1NotInZ2.show
inZ2NotInZ1.show