如何检测pyspark中的空列

How to detect null column in pyspark

我有一个用一些空值定义的数据框。有些列完全是空值。

>> df.show()
+---+---+---+----+
|  A|  B|  C|   D|
+---+---+---+----+
|1.0|4.0|7.0|null|
|2.0|5.0|7.0|null|
|3.0|6.0|5.0|null|
+---+---+---+----+

在我的例子中,我想要 return 一个用空值填充的列名列表。我的想法是检测常量列(因为整个列包含相同的空值)。

我是这样做的:

nullCoulumns = [c for c, const in df.select([(min(c) == max(c)).alias(c) for c in df.columns]).first().asDict().items() if const] 

但这不会将空列视为常量,它仅适用于值。 那我应该怎么做呢?

将条件扩展到

from pyspark.sql.functions import min, max

((min(c).isNull() & max(c).isNull()) | (min(c) == max(c))).alias(c) 

或使用 eqNullSafe (PySpark 2.3):

(min(c).eqNullSafe(max(c))).alias(c) 

一种方法是隐式执行:select 每列,计算其 NULL 值,然后将其与总数或行数进行比较。使用您的数据,这将是:

spark.version
# u'2.2.0'

from pyspark.sql.functions import col

nullColumns = []
numRows = df.count()
for k in df.columns:
  nullRows = df.where(col(k).isNull()).count()
  if nullRows ==  numRows: # i.e. if ALL values are NULL
    nullColumns.append(k)

nullColumns
# ['D']

但是有一个更简单的方法:事实证明,函数 countDistinct,当应用于具有所有 NULL 值的列时,returns 零 (0):

from pyspark.sql.functions import countDistinct

df.agg(countDistinct(df.D).alias('distinct')).collect()
# [Row(distinct=0)]

所以 for 循环现在可以是:

nullColumns = []
for k in df.columns:
  if df.agg(countDistinct(df[k])).collect()[0][0] == 0:
    nullColumns.append(k)

nullColumns
# ['D']

UPDATE(评论后):在第二个解决方案中似乎可以避免 collect;由于 df.agg returns 只有一行的数据框,将 collect 替换为 take(1) 将安全地完成工作:

nullColumns = []
for k in df.columns:
  if df.agg(countDistinct(df[k])).take(1)[0][0] == 0:
    nullColumns.append(k)

nullColumns
# ['D']

这个怎么样?为了保证列是 all 空值,必须满足两个属性:

(1) 最小值等于最大值

(2) 最小值最大值为空

或者,等价地

(1) 最小值和最大值都等于 None

请注意,如果 属性 (2) 不满足,列值为 [null, 1, null, 1] 的情况将被错误报告,因为最小值和最大值将为 1

import pyspark.sql.functions as F

def get_null_column_names(df):
    column_names = []

    for col_name in df.columns:

        min_ = df.select(F.min(col_name)).first()[0]
        max_ = df.select(F.max(col_name)).first()[0]

        if min_ is None and max_ is None:
            column_names.append(col_name)

    return column_names

这是一个实践中的例子:

>>> rows = [(None, 18, None, None),
            (1, None, None, None),
            (1, 9, 4.0, None),
            (None, 0, 0., None)]

>>> schema = "a: int, b: int, c: float, d:int"

>>> df = spark.createDataFrame(data=rows, schema=schema)

>>> df.show()

+----+----+----+----+
|   a|   b|   c|   d|
+----+----+----+----+
|null|  18|null|null|
|   1|null|null|null|
|   1|   9| 4.0|null|
|null|   0| 0.0|null|
+----+----+----+----+

>>> get_null_column_names(df)
['d']