我如何从 Apache Spark 中的 RFormula/RFormulaModel 获取索引映射的因子?

How do i get the factor to index mappings from RFormula/RFormulaModel in Apache Spark?

全部,

我有一个像下面这样的简单数据框

我正在使用 RFormula api 制作如下模型矩阵

val formula = "dep ~ indep"
val rF = new RFormula().setFormula(formula).setFeaturesCol("features").setLabelCol("label")
val rfModel = rF.fit(df)

其中 rfModel 是 RFormulaModel 类型。根据文档 here

分类变量 "indep" 的映射应该可以作为 pipelineModel 从这个对象访问,但这似乎是一个私有成员。

我的问题是如何从 RFormulaModel 对象中获取标签和相应的索引?我知道我可以使用转换后的数据帧的元数据并进行字符串操作,但是有没有一种直接的方法可以做到这一点?

感谢您的帮助!

想出了一个 hack,我必须将 RFormulaModel 写入磁盘,然后将 pipelineModel 部分作为 PipelineModel 读回。从那里我可以访问 StringIndexerModel 阶段,如此处所示

import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.feature.StringIndexerModel

rfModel.write.overwrite.save("/rfModel")
val pModel = PipelineModel.read.load("/rfModel/pipelineModel")

val strIndexers = pModel.stages.filter(stage => stage.isInstanceOf[StringIndexerModel])
val labelMaps = strIndexers.map(e  => { val i = e.asInstanceOf[StringIndexerModel]; (i.getInputCol, i.labels)})

在 pyspark 中,您可以从目标数据框中提取元数据,例如

for attr in df.schema["X"].metadata["ml_attr"]["attrs"]["numeric"]:
    print(f"idx: {attr['idx']}, name: {attr['name']}")

希望这也适用于scala!