如何在 pyspark 的高基数分类列中有效地对具有低频率计数的级别进行分组?
How to efficiently group levels with low frequency counts in a high cardinality categorical column in pyspark?
我目前正在尝试找到在 StringType()
的列中出现频率较低的分类列中对级别进行分组的有效方法。我想根据百分比阈值执行此操作,即替换出现在少于 z%
行中的所有值。此外,重要的是我们可以 return 数值(在应用 StringIndexer
之后)和原始值之间的映射。
所以基本上阈值是 25%,这个数据帧:
+---+---+---+---+
| x1| x2| x3| x4|
+---+---+---+---+
| a| a| a| a|
| b| b| a| b|
| a| a| a| c|
| b| b| a| d|
| c| a| a| e|
+---+---+---+---+
应该变成这样:
+------+------+------+------+
|x1_new|x2_new|x3_new|x4_new|
+------+------+------+------+
| a| a| a| other|
| b| b| a| other|
| a| a| a| other|
| b| b| a| other|
| other| a| a| other|
+------+------+------+------+
其中 c
已在 x1
列中替换为 other
,并且所有值已在列 x4
中替换为 other
,因为它们出现在少于 25%
行中。
我希望使用常规 StringIndexer
,并利用值根据其频率排序的事实。我们可以计算要保留多少个值并用例如替换所有其他值-1
。这种方法的问题:这会在稍后 IndexToString
内引发错误,我认为是因为元数据丢失了。
我的问题;有没有好的方法来做到这一点?是否有我可能忽略的内置功能?有没有办法保留元数据?
提前致谢!
df = pd.DataFrame({'x1' : ['a','b','a','b','c'], # a: 0.4, b: 0.4, c: 0.2
'x2' : ['a','b','a','b','a'], # a: 0.6, b: 0.4, c: 0.0
'x3' : ['a','a','a','a','a'], # a: 1.0, b: 0.0, c: 0.0
'x4' : ['a','b','c','d','e']}) # a: 0.2, b: 0.2, c: 0.2, d: 0.2, e: 0.2
df = sqlContext.createDataFrame(df)
我做了一些进一步的调查,偶然发现 关于将元数据添加到 pyspark 中的列。基于此,我能够创建一个名为 group_low_freq
的函数,我认为它非常有效;它只使用 StringIndexer
一次,然后修改此列和元数据以将出现次数少于 x%
的所有元素放入一个名为 "other" 的单独组中。由于我们还修改了元数据,因此我们稍后可以在 IndexToString
上检索字符串。下面给出函数和例子:
代码:
import findspark
findspark.init()
import pyspark as ps
from pyspark.sql import SQLContext, Column
import pandas as pd
import numpy as np
from pyspark.sql.functions import col, count as sparkcount, when, lit
from pyspark.sql.types import StringType
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.ml import Pipeline
import json
try:
sc
except NameError:
sc = ps.SparkContext()
sqlContext = SQLContext(sc)
from pyspark.sql.functions import col
def withMeta(self, alias, meta):
sc = ps.SparkContext._active_spark_context
jmeta = sc._gateway.jvm.org.apache.spark.sql.types.Metadata
return Column(getattr(self._jc, "as")(alias, jmeta.fromJson(json.dumps(meta))))
def group_low_freq(df,inColumns,threshold=.01,group_text='other'):
"""
Index string columns and group all observations that occur in less then a threshold% of the rows in df per column.
:param df: A pyspark.sql.dataframe.DataFrame
:param inColumns: String columns that need to be indexed
:param group_text: String to use as replacement for the observations that need to be grouped.
"""
total = df.count()
for string_col in inColumns:
# Apply string indexer
pipeline = Pipeline(stages=[StringIndexer(inputCol=string_col, outputCol="ix_"+string_col)])
df = pipeline.fit(df).transform(df)
# Calculate the number of unique elements to keep
n_to_keep = df.groupby(string_col).agg((sparkcount(string_col)/total).alias('perc')).filter(col('perc')>threshold).count()
# If elements occur below (threshold * number of rows), replace them with n_to_keep.
this_meta = df.select('ix_' + string_col).schema.fields[0].metadata
if n_to_keep != len(this_meta['ml_attr']['vals']):
this_meta['ml_attr']['vals'] = this_meta['ml_attr']['vals'][0:(n_to_keep+1)]
this_meta['ml_attr']['vals'][n_to_keep] = group_text
df = df.withColumn('ix_'+string_col,when(col('ix_'+string_col)>=n_to_keep,lit(n_to_keep)).otherwise(col('ix_'+string_col)))
# add the new column with correct metadata, remove original.
df = df.withColumn('ix_'+string_col, withMeta(col('ix_'+string_col), "", this_meta))
return df
# SAMPLE DATA -----------------------------------------------------------------
df = pd.DataFrame({'x1' : ['a','b','a','b','c'], # a: 0.4, b: 0.4, c: 0.2
'x2' : ['a','b','a','b','a'], # a: 0.6, b: 0.4, c: 0.0
'x3' : ['a','a','a','a','a'], # a: 1.0, b: 0.0, c: 0.0
'x4' : ['a','b','c','d','e']}) # a: 0.2, b: 0.2, c: 0.2, d: 0.2, e: 0.2
df = sqlContext.createDataFrame(df)
# TEST THE FUNCTION -----------------------------------------------------------
df = group_low_freq(df,df.columns,0.25)
ix_cols = [x for x in df.columns if 'ix_' in x]
for string_col in ix_cols:
idx_to_string = IndexToString(inputCol=string_col, outputCol=string_col[3:]+'grouped')
df = idx_to_string.transform(df)
df.show()
阈值为 25% 的输出(因此每个组必须至少出现在 25% 的行中):
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
| x1| x2| x3| x4|ix_x1|ix_x2|ix_x3|ix_x4|x1grouped|x2grouped|x3grouped|x4grouped|
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
| a| a| a| a| 0.0| 0.0| 0.0| 0.0| a| a| a| other|
| b| b| a| b| 1.0| 1.0| 0.0| 0.0| b| b| a| other|
| a| a| a| c| 0.0| 0.0| 0.0| 0.0| a| a| a| other|
| b| b| a| d| 1.0| 1.0| 0.0| 0.0| b| b| a| other|
| c| a| a| e| 2.0| 0.0| 0.0| 0.0| other| a| a| other|
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
我目前正在尝试找到在 StringType()
的列中出现频率较低的分类列中对级别进行分组的有效方法。我想根据百分比阈值执行此操作,即替换出现在少于 z%
行中的所有值。此外,重要的是我们可以 return 数值(在应用 StringIndexer
之后)和原始值之间的映射。
所以基本上阈值是 25%,这个数据帧:
+---+---+---+---+
| x1| x2| x3| x4|
+---+---+---+---+
| a| a| a| a|
| b| b| a| b|
| a| a| a| c|
| b| b| a| d|
| c| a| a| e|
+---+---+---+---+
应该变成这样:
+------+------+------+------+
|x1_new|x2_new|x3_new|x4_new|
+------+------+------+------+
| a| a| a| other|
| b| b| a| other|
| a| a| a| other|
| b| b| a| other|
| other| a| a| other|
+------+------+------+------+
其中 c
已在 x1
列中替换为 other
,并且所有值已在列 x4
中替换为 other
,因为它们出现在少于 25%
行中。
我希望使用常规 StringIndexer
,并利用值根据其频率排序的事实。我们可以计算要保留多少个值并用例如替换所有其他值-1
。这种方法的问题:这会在稍后 IndexToString
内引发错误,我认为是因为元数据丢失了。
我的问题;有没有好的方法来做到这一点?是否有我可能忽略的内置功能?有没有办法保留元数据?
提前致谢!
df = pd.DataFrame({'x1' : ['a','b','a','b','c'], # a: 0.4, b: 0.4, c: 0.2
'x2' : ['a','b','a','b','a'], # a: 0.6, b: 0.4, c: 0.0
'x3' : ['a','a','a','a','a'], # a: 1.0, b: 0.0, c: 0.0
'x4' : ['a','b','c','d','e']}) # a: 0.2, b: 0.2, c: 0.2, d: 0.2, e: 0.2
df = sqlContext.createDataFrame(df)
我做了一些进一步的调查,偶然发现 group_low_freq
的函数,我认为它非常有效;它只使用 StringIndexer
一次,然后修改此列和元数据以将出现次数少于 x%
的所有元素放入一个名为 "other" 的单独组中。由于我们还修改了元数据,因此我们稍后可以在 IndexToString
上检索字符串。下面给出函数和例子:
代码:
import findspark
findspark.init()
import pyspark as ps
from pyspark.sql import SQLContext, Column
import pandas as pd
import numpy as np
from pyspark.sql.functions import col, count as sparkcount, when, lit
from pyspark.sql.types import StringType
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.ml import Pipeline
import json
try:
sc
except NameError:
sc = ps.SparkContext()
sqlContext = SQLContext(sc)
from pyspark.sql.functions import col
def withMeta(self, alias, meta):
sc = ps.SparkContext._active_spark_context
jmeta = sc._gateway.jvm.org.apache.spark.sql.types.Metadata
return Column(getattr(self._jc, "as")(alias, jmeta.fromJson(json.dumps(meta))))
def group_low_freq(df,inColumns,threshold=.01,group_text='other'):
"""
Index string columns and group all observations that occur in less then a threshold% of the rows in df per column.
:param df: A pyspark.sql.dataframe.DataFrame
:param inColumns: String columns that need to be indexed
:param group_text: String to use as replacement for the observations that need to be grouped.
"""
total = df.count()
for string_col in inColumns:
# Apply string indexer
pipeline = Pipeline(stages=[StringIndexer(inputCol=string_col, outputCol="ix_"+string_col)])
df = pipeline.fit(df).transform(df)
# Calculate the number of unique elements to keep
n_to_keep = df.groupby(string_col).agg((sparkcount(string_col)/total).alias('perc')).filter(col('perc')>threshold).count()
# If elements occur below (threshold * number of rows), replace them with n_to_keep.
this_meta = df.select('ix_' + string_col).schema.fields[0].metadata
if n_to_keep != len(this_meta['ml_attr']['vals']):
this_meta['ml_attr']['vals'] = this_meta['ml_attr']['vals'][0:(n_to_keep+1)]
this_meta['ml_attr']['vals'][n_to_keep] = group_text
df = df.withColumn('ix_'+string_col,when(col('ix_'+string_col)>=n_to_keep,lit(n_to_keep)).otherwise(col('ix_'+string_col)))
# add the new column with correct metadata, remove original.
df = df.withColumn('ix_'+string_col, withMeta(col('ix_'+string_col), "", this_meta))
return df
# SAMPLE DATA -----------------------------------------------------------------
df = pd.DataFrame({'x1' : ['a','b','a','b','c'], # a: 0.4, b: 0.4, c: 0.2
'x2' : ['a','b','a','b','a'], # a: 0.6, b: 0.4, c: 0.0
'x3' : ['a','a','a','a','a'], # a: 1.0, b: 0.0, c: 0.0
'x4' : ['a','b','c','d','e']}) # a: 0.2, b: 0.2, c: 0.2, d: 0.2, e: 0.2
df = sqlContext.createDataFrame(df)
# TEST THE FUNCTION -----------------------------------------------------------
df = group_low_freq(df,df.columns,0.25)
ix_cols = [x for x in df.columns if 'ix_' in x]
for string_col in ix_cols:
idx_to_string = IndexToString(inputCol=string_col, outputCol=string_col[3:]+'grouped')
df = idx_to_string.transform(df)
df.show()
阈值为 25% 的输出(因此每个组必须至少出现在 25% 的行中):
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
| x1| x2| x3| x4|ix_x1|ix_x2|ix_x3|ix_x4|x1grouped|x2grouped|x3grouped|x4grouped|
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
| a| a| a| a| 0.0| 0.0| 0.0| 0.0| a| a| a| other|
| b| b| a| b| 1.0| 1.0| 0.0| 0.0| b| b| a| other|
| a| a| a| c| 0.0| 0.0| 0.0| 0.0| a| a| a| other|
| b| b| a| d| 1.0| 1.0| 0.0| 0.0| b| b| a| other|
| c| a| a| e| 2.0| 0.0| 0.0| 0.0| other| a| a| other|
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+