拟合管道和处理数据

Fitting pipeline and processing the data

我有一个包含文本的文件。我想要做的是使用管道对文本进行标记化,删除停用词并生成 2-grams。

到目前为止我做了什么:

第 1 步:读取文件

val data = sparkSession.read.text("data.txt").toDF("text")

第 2 步:构建管道

val pipe1 = new Tokenizer().setInputCol("text").setOutputCol("words")
val pipe2 = new StopWordsRemover().setInputCol("words").setOutputCol("filtered")
val pipe3 = new NGram().setN(2).setInputCol("filtered").setOutputCol("ngrams")

val pipeline = new Pipeline().setStages(Array(pipe1, pipe2, pipe3))
val model = pipeline.fit(data)

我知道 pipeline.fit(data) 会生成 PipelineModel 但是我不知道如何使用 PipelineModel

如有任何帮助,我们将不胜感激。

当您 运行 val model = pipeline.fit(data) 代码时,所有 Estimator 阶段(即:机器学习任务,如分类、回归、聚类等)都适合数据和 Transformer 舞台已创建。您只有 Transformer 个阶段,因为您在此管道中进行功能创建。

为了执行您的模型,现在只包含 Transformer 个阶段,您需要 运行 val results = model.transform(data)。这将针对您的数据框执行每个 Transformer 阶段。因此,在 model.transform(data) 过程结束时,您将拥有一个由原始行、Tokenizer 输出、StopWordsRemover 输出以及最后的 NGram 结果组成的数据帧。

可以通过 SparkSQL 查询在特征创建完成后发现前 5 个 ngram。先对ngram列进行爆破,然后对ngrams进行分组统计,按统计的列降序排列,然后执行show(5)。或者,您可以使用 "LIMIT 5 方法代替 show(5).

顺便说一句,您可能应该将对象名称更改为非标准 class 名称。否则你会得到一个模糊的范围错误。

代码:

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.Tokenizer
import org.apache.spark.sql.SparkSession._
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.NGram
import org.apache.spark.ml.feature.StopWordsRemover
import org.apache.spark.ml.{Pipeline, PipelineModel}

object NGramPipeline {
    def main() {
        val sparkSession = SparkSession.builder.appName("NGram Pipeline").getOrCreate()

        val sc = sparkSession.sparkContext

        val data = sparkSession.read.text("quangle.txt").toDF("text")

        val pipe1 = new Tokenizer().setInputCol("text").setOutputCol("words")
        val pipe2 = new StopWordsRemover().setInputCol("words").setOutputCol("filtered")
        val pipe3 = new NGram().setN(2).setInputCol("filtered").setOutputCol("ngrams")

        val pipeline = new Pipeline().setStages(Array(pipe1, pipe2, pipe3))
        val model = pipeline.fit(data)

        val results = model.transform(data)

        val explodedNGrams = results.withColumn("explNGrams", explode($"ngrams"))
        explodedNGrams.groupBy("explNGrams").agg(count("*") as "ngramCount").orderBy(desc("ngramCount")).show(10,false)

    }
}
NGramPipeline.main()



输出:

+-----------------+----------+
|explNGrams       |ngramCount|
+-----------------+----------+
|quangle wangle   |9         |
|wangle quee.     |4         |
|'mr. quangle     |3         |
|said, --         |2         |
|wangle said      |2         |
|crumpetty tree   |2         |
|crumpetty tree,  |2         |
|quangle wangle,  |2         |
|crumpetty tree,--|2         |
|blue babboon,    |2         |
+-----------------+----------+
only showing top 10 rows

请注意,存在导致行重复的语法(逗号、破折号等)。执行 ngram 时,过滤我们的语法通常是个好主意。您通常可以使用正则表达式执行此操作。