在 pyspark 中用整数编码列
Encode a column with integer in pyspark
我必须在 pyspark(spark 2.0) 的大 DataFrame 中对列进行编码。所有的值几乎都是唯一的(大约 10 亿个值)。
最好的选择可能是 StringIndexer,但由于某种原因它总是失败并终止我的 spark 会话。
我能以某种方式编写这样的函数吗:
id_dict() = dict()
def indexer(x):
id_dict.setdefault(x, len(id_dict))
return id_dict[x]
并将其映射到 DataFrame 并 id_dict 保存 items()?这个字典会在每个执行者上同步吗?
我需要所有这些来为 spark.mllib ALS 模型预处理元组 ('x', 3, 5)。
谢谢你。
StringIndexer
将所有标签保存在内存中,因此如果值几乎是唯一的,它就不会缩放。
您可以采用唯一值、排序和添加 id,这很昂贵,但在这种情况下更健壮:
from pyspark.sql.functions import monotonically_increasing_id
df = spark.createDataFrame(["a", "b", "c", "a", "d"], "string").toDF("value")
indexer = (df.select("value").distinct()
.orderBy("value")
.withColumn("label", monotonically_increasing_id()))
df.join(indexer, ["value"]).show()
# +-----+-----------+
# |value| label|
# +-----+-----------+
# | d|25769803776|
# | c|17179869184|
# | b| 8589934592|
# | a| 0|
# | a| 0|
# +-----+-----------+
请注意,标签不是连续的,可能与 运行 运行 不同,或者随着 spark.sql.shuffle.partitions
的变化而变化。如果不可接受,您将不得不使用 RDDs
:
from operator import itemgetter
indexer = (df.select("value").distinct()
.rdd.map(itemgetter(0)).zipWithIndex()
.toDF(["value", "label"]))
df.join(indexer, ["value"]).show()
# +-----+-----+
# |value|label|
# +-----+-----+
# | d| 0|
# | c| 1|
# | b| 2|
# | a| 3|
# | a| 3|
# +-----+-----+
我必须在 pyspark(spark 2.0) 的大 DataFrame 中对列进行编码。所有的值几乎都是唯一的(大约 10 亿个值)。 最好的选择可能是 StringIndexer,但由于某种原因它总是失败并终止我的 spark 会话。 我能以某种方式编写这样的函数吗:
id_dict() = dict()
def indexer(x):
id_dict.setdefault(x, len(id_dict))
return id_dict[x]
并将其映射到 DataFrame 并 id_dict 保存 items()?这个字典会在每个执行者上同步吗? 我需要所有这些来为 spark.mllib ALS 模型预处理元组 ('x', 3, 5)。 谢谢你。
StringIndexer
将所有标签保存在内存中,因此如果值几乎是唯一的,它就不会缩放。
您可以采用唯一值、排序和添加 id,这很昂贵,但在这种情况下更健壮:
from pyspark.sql.functions import monotonically_increasing_id
df = spark.createDataFrame(["a", "b", "c", "a", "d"], "string").toDF("value")
indexer = (df.select("value").distinct()
.orderBy("value")
.withColumn("label", monotonically_increasing_id()))
df.join(indexer, ["value"]).show()
# +-----+-----------+
# |value| label|
# +-----+-----------+
# | d|25769803776|
# | c|17179869184|
# | b| 8589934592|
# | a| 0|
# | a| 0|
# +-----+-----------+
请注意,标签不是连续的,可能与 运行 运行 不同,或者随着 spark.sql.shuffle.partitions
的变化而变化。如果不可接受,您将不得不使用 RDDs
:
from operator import itemgetter
indexer = (df.select("value").distinct()
.rdd.map(itemgetter(0)).zipWithIndex()
.toDF(["value", "label"]))
df.join(indexer, ["value"]).show()
# +-----+-----+
# |value|label|
# +-----+-----+
# | d| 0|
# | c| 1|
# | b| 2|
# | a| 3|
# | a| 3|
# +-----+-----+