Converting entropy calculation from Scala Spark to PySpark

Environment: Spark 2.4.4

I am trying to convert the following code from Scala Spark to PySpark:

test.registerTempTable("test")

val df = sqlContext.sql("select cluster as _1, count(*) as _2 from test group by cluster, label order by cluster desc")

import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy("_1").rowsBetween(Long.MinValue, Long.MaxValue)

import org.apache.spark.sql.functions.sum

val p = $"_2" / sum($"_2").over(w)
val withP = df.withColumn("p", p)

import org.apache.spark.sql.functions.log2

val result = withP.groupBy($"_1").agg((-sum($"p" * log2($"p"))).alias("entropy"))

result.collect()

It runs and produces the desired output:

Array[org.apache.spark.sql.Row] = Array([179,0.1091158547868134], [178,0.181873874177682], [177,-0.0], [176,0.9182958340544896], [175,-0.0], [174,-0.0], [173,0.04848740692447222], [172,-0.0], [171,-0.0], [170,-0.0], [169,-...
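
(For context, each value is the Shannon entropy of the label distribution within a cluster: H = -Σ p_i · log2(p_i), where p_i is a (cluster, label) count divided by that cluster's total count.)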

The PySpark version runs fine up to the last statement, but then raises an AnalysisException:

df = sqlContext.sql("select cluster as _1, count(*) as _2 from test group by cluster, label order by cluster desc")

from pyspark.sql import Window

w = Window.partitionBy("_1").rowsBetween(-9223372036854775808L, 9223372036854775807L)

from pyspark.sql.functions import sum 

p = df['_2'] / sum(df['_2']).over(w)
withP = df.withColumn("p", p)

from pyspark.sql.functions import log2 

result = withP.groupBy("_1").agg((-sum(p * log2(p))).alias("entropy"))

The exception:

Fail to execute line 19: result = withP.groupBy("_1").agg(sum(p * log2(p)).alias("entropy"))
Traceback (most recent call last):
  File "/tmp/zeppelin_pyspark-6317327282796051870.py", line 380, in <module>
    exec(code, _zcUserQueryNameSpace)
  File "<stdin>", line 19, in <module>
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/group.py", line 115, in agg
    _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py", line 69, in deco
    raise AnalysisException(s.split(': ', 1)[1], stackTrace)
AnalysisException: u'It is not allowed to use a window function inside an aggregate function. Please use the inner window function in a sub-query.;'

A sample of the original DataFrame:

df = spark.createDataFrame([(1, 10), (1, 1), (2, 10), (3, 1), (3, 100)])

Why does the Scala version work, while the PySpark version with exactly the same logic does not?

There is a conflict between the column name p and the Column object p: in the PySpark version, the variable p still holds the window expression, so the aggregation ends up with a window function inside it. Use col("p") in the sum aggregation instead; this should work fine:

from pyspark.sql.functions import col

result = withP.groupBy("_1").agg((-sum(col("p") * log2(col("p")))).alias("entropy"))
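
For completeness, here is one way the whole corrected pipeline could look as a self-contained snippet (a sketch assuming the same test temp table; it also uses Window.unboundedPreceding/Window.unboundedFollowing instead of the raw long literals, and aliases sum to avoid shadowing Python's builtin):

from pyspark.sql import Window
from pyspark.sql.functions import col, log2, sum as sum_

df = sqlContext.sql("select cluster as _1, count(*) as _2 from test group by cluster, label order by cluster desc")

# window spanning the entire cluster partition
w = Window.partitionBy("_1").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

# share of each (cluster, label) count within its cluster
withP = df.withColumn("p", col("_2") / sum_("_2").over(w))

# referencing the materialized column "p" keeps the window function
# out of the aggregate expression
result = withP.groupBy("_1").agg((-sum_(col("p") * log2(col("p")))).alias("entropy"))

result.collect()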