Pyspark: Get the maximum prediction value for each row in a new column from a dense vector column

I have a PySpark DataFrame to which I applied a random forest classifier (from pyspark.ml.classification import RandomForestClassifier) on multiclass data.

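Now I have the prediction and probability columns (probability is a dense vector column). I want a new column holding the single highest probability from the probability column, i.e. the one that corresponds to the prediction. Can anyone suggest a way to do this?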
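To make the goal concrete: with default settings we would expect the classifier to predict the class with the largest probability, so the "highest probability" and the "probability of the predicted class" should be the same number (this may not hold if custom thresholds are set). A plain-Python sketch of that per-row computation, using made-up rows rather than the truncated values in the table:

```python
from typing import List, Sequence, Tuple

# Hypothetical rows: (probability vector, prediction). Made up for illustration only.
rows: List[Tuple[Sequence[float], float]] = [
    ([0.05, 0.10, 0.85], 2.0),
    ([0.60, 0.30, 0.10], 0.0),
    ([0.20, 0.70, 0.10], 1.0),
]


def max_and_predicted_prob(
    rows: List[Tuple[Sequence[float], float]]
) -> List[Tuple[float, float]]:
    """For each row, return (highest probability, probability of the predicted class)."""
    result = []
    for probability, prediction in rows:
        highest = max(probability)                    # what array_max computes per row
        of_prediction = probability[int(prediction)]  # prediction is a float class index
        result.append((highest, of_prediction))
    return result


for highest, of_prediction in max_and_predicted_prob(rows):
    print(highest, of_prediction)
```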

+--------------------+----------+--------------+
|         probability|prediction|predictedLabel|
+--------------------+----------+--------------+
|[0.04980166062108...|       9.0|          73.0|
|[0.09709955311030...|       2.0|          92.0|
|[0.00206441341895...|       1.0|          97.0|
|[0.01177280567423...|       8.0|          26.0|
|[0.09170364155771...|       4.0|          78.0|
|[0.09332145486133...|       0.0|          95.0|
|[0.15873541380236...|       0.0|          95.0|
|[0.21929050786626...|       0.0|          95.0|
|[0.08840100103254...|       1.0|          97.0|
|[0.06204585465363...|       1.0|          97.0|
|[0.06961837644280...|       1.0|          97.0|
|[0.04529447218955...|       1.0|          97.0|
|[0.02129073891494...|       2.0|          92.0|
|[0.02692350960234...|       1.0|          97.0|
|[0.02676868258573...|       8.0|          26.0|
|[0.01849528482881...|       1.0|          97.0|
|[0.10405735702064...|       1.0|          97.0|
|[0.01636762299564...|       1.0|          97.0|
|[0.01739759717529...|       1.0|          97.0|
|[0.02129073891494...|       2.0|          92.0|
+--------------------+----------+--------------+

You want the largest value in each row's probability array, so array_max (pyspark.sql.functions.array_max) is the natural fit.

Example (based on the documentation):

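from pyspark.sql import SparkSession
from pyspark.sql.functions import array_max

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([([0.04, 0.03, 0.01],), ([0.09, 0.05, 0.09],)], ['probability'])
# array_max returns the largest element of each row's array
df = df.withColumn("max_prob", array_max(df.probability))
df.show()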

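Update: the probability column produced by the classifier is an ML vector, not an array, so array_max cannot be applied to it directly. Convert it with vector_to_array() first (available in pyspark.ml.functions since Spark 3.0):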

from pyspark.sql.functions import array_max
from pyspark.ml.functions import vector_to_array
# Convert the ML vector to array<double>, then take the maximum of each row
df = df.withColumn("max_prob", array_max(vector_to_array(df.probability)))

This should give you:

+------------------+--------+
|       probability|max_prob|
+------------------+--------+
|[0.04, 0.03, 0.01]|    0.04|
|[0.09, 0.05, 0.09]|    0.09|
+------------------+--------+
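On Spark versions older than 3.0, where vector_to_array is not available, a common fallback is a Python UDF. Below is a minimal sketch of the per-row function only, under the assumption that a pyspark.ml.linalg DenseVector behaves like an indexable sequence of floats (it supports len() and integer indexing). The name max_probability is hypothetical, and the UDF wiring in the trailing comments is an untested illustration, not something from the answer above.

```python
from typing import Sequence


def max_probability(probability: Sequence[float]) -> float:
    """Return the largest class probability in one row (hypothetical helper).

    Accepts any indexable sequence of floats; we assume a DenseVector qualifies.
    """
    if len(probability) == 0:
        raise ValueError("probability vector is empty")
    # Index explicitly instead of iterating, then cast to a plain Python float
    # so the value fits a DoubleType column.
    return float(max(probability[i] for i in range(len(probability))))


# Untested sketch of wiring it up on Spark < 3.0 (needs a running SparkSession):
# from pyspark.sql.functions import udf
# from pyspark.sql.types import DoubleType
# max_probability_udf = udf(max_probability, DoubleType())
# df = df.withColumn("max_prob", max_probability_udf(df.probability))
```

The built-in array_max(vector_to_array(...)) route stays inside the JVM, so it should be preferred over a Python UDF whenever Spark 3.0+ is available.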