TypeError: Column is not iterable - How to iterate over ArrayType()?

Consider the following DataFrame:

+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[john, sam, jane]      |
|pet   |[whiskers, rover, fido]|
+------+-----------------------+

It can be created with the following code:

import pyspark.sql.functions as f
data = [
    ('person', ['john', 'sam', 'jane']),
    ('pet', ['whiskers', 'rover', 'fido'])
]

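# sqlCtx is a pre-existing SQLContext; on Spark 2.x+ a SparkSession
# (commonly named spark) can be used in the same way.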
df = sqlCtx.createDataFrame(data, ["type", "names"])
df.show(truncate=False)

Is there a way to directly modify the ArrayType() column "names" by applying a function to each element, without using a udf?

For example, suppose I wanted to apply a function foo to the "names" column. (I will use the example where foo is str.upper just for illustrative purposes, but my question is about any valid function that can be applied to the elements of an iterable.)

foo = lambda x: x.upper()  # defining it as str.upper as an example
df.withColumn('X', [foo(x) for x in f.col("names")]).show()

TypeError: Column is not iterable

I could use a udf:

from pyspark.sql.types import ArrayType, StringType

foo_udf = f.udf(lambda row: [foo(x) for x in row], ArrayType(StringType()))
df.withColumn('names', foo_udf(f.col('names'))).show(truncate=False)
#+------+-----------------------+
#|type  |names                  |
#+------+-----------------------+
#|person|[JOHN, SAM, JANE]      |
#|pet   |[WHISKERS, ROVER, FIDO]|
#+------+-----------------------+

In this specific example, I could avoid the udf by exploding the column, calling pyspark.sql.functions.upper(), and then doing a groupBy and collect_list:

df.select('type', f.explode('names').alias('name'))\
    .withColumn('name', f.upper(f.col('name')))\
    .groupBy('type')\
    .agg(f.collect_list('name').alias('names'))\
    .show(truncate=False)
#+------+-----------------------+
#|type  |names                  |
#+------+-----------------------+
#|person|[JOHN, SAM, JANE]      |
#|pet   |[WHISKERS, ROVER, FIDO]|
#+------+-----------------------+

But that is a lot of code to do something simple. Is there a more direct way to iterate over the elements of an ArrayType() using spark-dataframe functions?

Yes, you can convert it to an RDD and then back to a DF.

>>> df.show(truncate=False)
+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[john, sam, jane]      |
|pet   |[whiskers, rover, fido]|
+------+-----------------------+

>>> df.rdd.mapValues(lambda x: [y.upper() for y in x]).toDF(["type","names"]).show(truncate=False)
+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[JOHN, SAM, JANE]      |
|pet   |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+
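
This works because a Row behaves like a tuple, so mapValues keeps the first column (type) as the key and only rewrites the second (names). If you would rather reuse the original schema instead of re-listing the column names, a minimal variant (assuming the same df as above) is:

>>> # Rebuild each record from the full Row and reuse the existing schema.
>>> df.rdd.map(lambda r: (r["type"], [x.upper() for x in r["names"]])).toDF(df.schema).show(truncate=False)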

In Spark < 2.4 you can use a user defined function:

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DataType, StringType

def transform(f, t=StringType()):
    if not isinstance(t, DataType):
       raise TypeError("Invalid type {}".format(type(t)))
    @udf(ArrayType(t))
    def _(xs):
        if xs is not None:
            return [f(x) for x in xs]
    return _

foo_udf = transform(str.upper)

df.withColumn('names', foo_udf(f.col('names'))).show(truncate=False)
+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[JOHN, SAM, JANE]      |
|pet   |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+
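
The t argument is there so the same helper can return arrays of other element types. A hypothetical usage sketch (num_df, its numbers column, and the spark SparkSession are assumptions, not part of the question's data):

from pyspark.sql.types import IntegerType

# Hypothetical example: num_df and its numbers column are made up here, and
# spark is assumed to be an existing SparkSession.
num_df = spark.createDataFrame([(1, [1, 2, 3])], ["id", "numbers"])
double_udf = transform(lambda x: x * 2, IntegerType())  # returns array<int>
num_df.withColumn('numbers', double_udf(f.col('numbers'))).show()
# numbers becomes [2, 4, 6]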

Considering the high cost of the explode + collect_list idiom, this approach is almost exclusively preferred, despite its intrinsic cost.

In Spark 2.4 or later you can use transform* with upper (see SPARK-23909):

from pyspark.sql.functions import expr

df.withColumn(
    'names', expr('transform(names, x -> upper(x))')
).show(truncate=False)
+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[JOHN, SAM, JANE]      |
|pet   |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+
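
On Spark 3.1 or later, transform is also exposed directly in pyspark.sql.functions, so the expr string can be avoided (a sketch reusing the f alias from the question):

# f.transform takes a Python lambda over Column expressions (Spark 3.1+).
df.withColumn('names', f.transform('names', lambda x: f.upper(x))).show(truncate=False)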

It is also possible to use pandas_udf:

from pyspark.sql.functions import pandas_udf, PandasUDFType

def transform_pandas(f, t=StringType()):
    if not isinstance(t, DataType):
       raise TypeError("Invalid type {}".format(type(t)))
    @pandas_udf(ArrayType(t), PandasUDFType.SCALAR)
    def _(xs):
        return xs.apply(lambda xs: [f(x) for x in xs] if xs is not None else xs)
    return _

foo_udf_pandas = transform_pandas(str.upper)

df.withColumn('names', foo_udf_pandas(f.col('names'))).show(truncate=False)
+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[JOHN, SAM, JANE]      |
|pet   |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+

although only the latest Arrow / PySpark combinations support handling ArrayType columns (SPARK-24259, SPARK-21187). Nevertheless, this option should be more efficient than a standard UDF (especially with lower serde overhead) while still supporting arbitrary Python functions.


* A number of other higher order functions are also supported, including, but not limited to, filter and aggregate (a brief sketch follows the links below). See for example:

  • How to slice and sum elements of array column?
  • Filter array column content
  • Spark Scala row-wise average by handling null.
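
For instance, a brief sketch of filter and aggregate applied to the same names column (Spark 2.4+ SQL syntax; the literal 'rover' is just an illustrative value):

df.select(
    'type',
    # keep only elements that differ from 'rover'
    expr("filter(names, x -> x != 'rover')").alias('filtered'),
    # fold the array into one concatenated string
    expr("aggregate(names, '', (acc, x) -> concat(acc, x))").alias('joined')
).show(truncate=False)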