Exception when using UDT in Spark DataFrame

I am trying to create a user-defined type in Spark SQL, but even when following their example I get: com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType. Has anyone got this to work?

My code:

test("udt serialisation") {
    val points = Seq(new ExamplePoint(1.3, 1.6), new ExamplePoint(1.3, 1.8))
    val df = SparkContextForStdout.context.parallelize(points).toDF()
}

@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) 
case class ExamplePoint(val x: Double, val y: Double)

import java.util

import scala.collection.JavaConverters._

import org.apache.spark.sql.types._

/**
 * User-defined type for [[ExamplePoint]].
 */
class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

  override def sqlType: DataType = ArrayType(DoubleType, false)

  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

  override def serialize(obj: Any): Seq[Double] = {
    obj match {
      case p: ExamplePoint =>
        Seq(p.x, p.y)
    }
  }

  override def deserialize(datum: Any): ExamplePoint = {
    datum match {
      case values: Seq[_] =>
        val xy = values.asInstanceOf[Seq[Double]]
        assert(xy.length == 2)
        new ExamplePoint(xy(0), xy(1))
      case values: util.ArrayList[_] =>
        val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
        new ExamplePoint(xy(0), xy(1))
    }
  }

  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]

}

The relevant part of the stack trace is:

com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType
java.lang.ClassCastException: com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType
    at org.apache.spark.sql.SQLContext.createDataFrame(SQLContext.scala:316)
    at org.apache.spark.sql.SQLContext$implicits$.rddToDataFrameHolder(SQLContext.scala:254)

It seems the UDT only works when it is used inside another class, i.e. as the type of a field. One workaround for using it directly is to wrap it in a Tuple1:

  test("udt serialisation") {
    val points = Seq(new Tuple1(new ExamplePoint(1.3, 1.6)), new Tuple1(new ExamplePoint(1.3, 1.8)))
    val df = SparkContextForStdout.context.parallelize(points).toDF()
    df.collect().foreach(println(_))
  }
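
For completeness, the same workaround can be written with a named wrapper case class instead of Tuple1, which gives the column a readable name in the schema. This is only a sketch: PointHolder is a hypothetical name, and SparkContextForStdout is the test helper from the question above.

    // Hypothetical wrapper: the UDT-annotated ExamplePoint is now the type
    // of a named field, so the derived schema is a proper StructType.
    case class PointHolder(point: ExamplePoint)

    test("udt serialisation via wrapper case class") {
      val points = Seq(PointHolder(new ExamplePoint(1.3, 1.6)),
                       PointHolder(new ExamplePoint(1.3, 1.8)))
      val df = SparkContextForStdout.context.parallelize(points).toDF()
      df.printSchema() // the "point" column is backed by ExamplePointUDT
      df.collect().foreach(println(_))
    }

Either way, the key point is the same: the schema inference path expects a top-level StructType, so the UDT must appear as a field of a product type (a tuple or a case class) rather than as the row type itself.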