PySpark UDF of MapType with mixed value type

I have a JSON input like this:


    {
      "1": {
        "id": 1,
        "value": 5
      },
      "2": {
        "id": 2,
        "list": {
          "10": {
            "id": 10
          },
          "11": {
            "id": 11
          },
          "20": {
            "id": 20
          }
        }
      },
      "3": {
        "id": 3,
        "key": "a"
      }
    }

I need to merge these 3 columns, extracting the required value from each; this is the output I need:


    {
      "out": {
        "1": 5,
        "2": [10, 11, 20],
        "3": "a"
      }
    }

I tried to create a UDF to convert these 3 columns into 1 column, but I don't know how to define a MapType() with mixed value types: IntegerType(), ArrayType(IntegerType()), and StringType() respectively.

Thanks in advance!

You need to use StructType instead of MapType to define the result type of your UDF, like this:

from pyspark.sql.types import *

udf_result = StructType([
    StructField('1', IntegerType()),
    StructField('2', ArrayType(IntegerType())),
    StructField('3', StringType())
])

MapType() is meant for (key, value) pair definitions, not for nested data frames. What you are looking for is StructType().
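For completeness, here is a minimal sketch of how a UDF could use this return type (merge_cols is a hypothetical name, and df is assumed to be the data frame loaded as shown further below):

from pyspark.sql.functions import udf

# Each argument arrives as a Row matching the corresponding input struct
@udf(returnType=udf_result)
def merge_cols(c1, c2, c3):
    ids = [v.id for v in c2["list"].asDict().values()]
    return (c1["value"], ids, c3["key"])

df.select(merge_cols("1", "2", "3").alias("out"))

That said, the pure DataFrame approach below avoids serializing every row through Python, so it is usually preferable.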

You could load it directly with createDataFrame, but then you would have to pass a schema, so this way is easier (a sketch of the schema-passing alternative follows the inferred schema below):

import json

data_json = {
    "1": {
        "id": 1,
        "value": 5
    },
    "2": {
        "id": 2,
        "list": {
            "10": {"id": 10},
            "11": {"id": 11},
            "20": {"id": 20}
        }
    },
    "3": {
        "id": 3,
        "key": "a"
    }
}

# Serialize the dict into a one-element RDD of JSON strings so that
# spark.read.json can infer the nested schema
a = [json.dumps(data_json)]
jsonRDD = sc.parallelize(a)
df = spark.read.json(jsonRDD)
df.printSchema()

    root
     |-- 1: struct (nullable = true)
     |    |-- id: long (nullable = true)
     |    |-- value: long (nullable = true)
     |-- 2: struct (nullable = true)
     |    |-- id: long (nullable = true)
     |    |-- list: struct (nullable = true)
     |    |    |-- 10: struct (nullable = true)
     |    |    |    |-- id: long (nullable = true)
     |    |    |-- 11: struct (nullable = true)
     |    |    |    |-- id: long (nullable = true)
     |    |    |-- 20: struct (nullable = true)
     |    |    |    |-- id: long (nullable = true)
     |-- 3: struct (nullable = true)
     |    |-- id: long (nullable = true)
     |    |-- key: string (nullable = true)
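For comparison, the createDataFrame route mentioned above needs this whole schema written out by hand (a sketch, with the types taken from the inferred schema above):

from pyspark.sql.types import StructType, StructField, LongType, StringType

# Hand-written equivalent of the schema Spark inferred from the JSON
schema = StructType([
    StructField("1", StructType([
        StructField("id", LongType()),
        StructField("value", LongType())
    ])),
    StructField("2", StructType([
        StructField("id", LongType()),
        StructField("list", StructType([
            StructField(k, StructType([StructField("id", LongType())]))
            for k in ["10", "11", "20"]
        ]))
    ])),
    StructField("3", StructType([
        StructField("id", LongType()),
        StructField("key", StringType())
    ]))
])
df = spark.createDataFrame([data_json], schema)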

Now to access the nested data frames. Note that column "2" is more deeply nested than the others:

nested_cols = ["2"]
cols = ["1", "3"]
import pyspark.sql.functions as psf
# Expand the struct of structs under "2.list" into an array of its fields
df = df.select(
    cols + [psf.array(psf.col(c + ".list.*")).alias(c) for c in nested_cols]
)
# Keep only the id field of every column; on the array column this applies element-wise
df = df.select(
    [df[c].id.alias(c) for c in df.columns]
)
df.printSchema()

    root
     |-- 1: long (nullable = true)
     |-- 3: long (nullable = true)
     |-- 2: array (nullable = false)
     |    |-- element: long (containsNull = true)
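If you prefer, the two selects can also be collapsed into a single equivalent one (same result as above):

df = df.select(
    psf.col("1.id").alias("1"),
    psf.col("3.id").alias("3"),
    psf.array(psf.col("2.list.*")).id.alias("2")
)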

This is not exactly your final output yet, since you want it nested inside an "out" column:

import pyspark.sql.functions as psf
df.select(psf.struct("*").alias("out")).printSchema()

    root
     |-- out: struct (nullable = false)
     |    |-- 1: long (nullable = true)
     |    |-- 3: long (nullable = true)
     |    |-- 2: array (nullable = false)
     |    |    |-- element: long (containsNull = true)

Finally, back to JSON:

df.toJSON().first()

    '{"1":1,"3":3,"2":[10,11,20]}'