如何在 Spark 2.1 上更新 pyspark 数据帧元数据?

How to update pyspark dataframe metadata on Spark 2.1?

我遇到了 SparkML 的 OneHotEncoder 问题,因为它读取数据帧元数据以确定它应该为其创建的稀疏向量对象分配的值范围。

更具体地说,我正在使用包含 0 到 23 之间所有单个值的训练集对 "hour" 字段进行编码。

现在我正在使用管道的 "transform" 方法对单行数据框进行评分。

不幸的是,这会导致 OneHotEncoder

的编码不同的稀疏向量对象

(24,[5],[1.0]) 与 (11,[10],[1.0])

我已经记录了这个 here, but this was identified as duplicate. So in this 有一个解决方案发布来更新数据帧的元数据以反映 "hour" 字段的真实范围:

from pyspark.sql.functions import col

meta = {"ml_attr": {
    "vals": [str(x) for x in range(6)],   # Provide a set of levels
    "type": "nominal", 
    "name": "class"}}

loaded.transform(
    df.withColumn("class", col("class").alias("class", metadata=meta)) )

不幸的是我得到了这个错误:

TypeError: alias() got an unexpected keyword argument 'metadata'

在 PySpark 2.1 中,alias 方法没有参数 metadatadocs) - this became available in Spark 2.2; nevertheless, it is still possible to modify column metadata in PySpark < 2.2, thanks to the incredible Spark Gotchas, maintained by @eliasah and @zero323:

import json

from pyspark import SparkContext
from pyspark.sql import Column
from pyspark.sql.functions import col

spark.version
# u'2.1.1'

df = sc.parallelize((
        (0, "x", 2.0),
        (1, "y", 3.0),
        (2, "x", -1.0)
        )).toDF(["label", "x1", "x2"])

df.show()
# +-----+---+----+ 
# |label| x1|  x2|
# +-----+---+----+
# |    0|  x| 2.0|
# |    1|  y| 3.0|
# |    2|  x|-1.0|
# +-----+---+----+

假设我们想要强制我们的 label 数据介于 0 和 5 之间的可能性,尽管我们的数据框中的数据介于 0 和 2 之间,下面是我们应该如何修改列元数据:

def withMeta(self, alias, meta):
    sc = SparkContext._active_spark_context
    jmeta = sc._gateway.jvm.org.apache.spark.sql.types.Metadata
    return Column(getattr(self._jc, "as")(alias, jmeta.fromJson(json.dumps(meta))))

Column.withMeta = withMeta

# new metadata:
meta = {"ml_attr": {"name": "label_with_meta",
                    "type": "nominal",
                    "vals": [str(x) for x in range(6)]}}

df_with_meta = df.withColumn("label_with_meta", col("label").withMeta("", meta))

zero323 也向 致敬!