How to check if array column is inside another column array in PySpark dataframe

Suppose I have the following setup:

from pyspark.sql.types import *

# Schema: an id plus two array-of-string columns.
schema = StructType([
    StructField("id", StringType(), True),
    StructField("ev", ArrayType(StringType()), True),
    StructField("ev2", ArrayType(StringType()), True),
])
df = spark.createDataFrame([{"id": "se1", "ev": ["ev11", "ev12"], "ev2": ["ev11"]},
                            {"id": "se2", "ev": ["ev11"], "ev2": ["ev11", "ev12"]},
                            {"id": "se3", "ev": ["ev21"], "ev2": ["ev11", "ev12"]},
                            {"id": "se4", "ev": ["ev21", "ev22"], "ev2": ["ev21", "ev22"]}],
                           schema=schema)

which gives me:

df.show()
+---+------------+------------+
| id|          ev|         ev2|
+---+------------+------------+
|se1|[ev11, ev12]|      [ev11]|
|se2|      [ev11]|[ev11, ev12]|
|se3|      [ev21]|[ev11, ev12]|
|se4|[ev21, ev22]|[ev21, ev22]|
+---+------------+------------+

我想为 "ev" 列的内容在 "ev2" 列内的行创建一个新的布尔值列(或 select 只有真实情况),返回:

df_target.show()
+---+------------+------------+
| id|          ev|         ev2|
+---+------------+------------+
|se2|      [ev11]|[ev11, ev12]|
|se4|[ev21, ev22]|[ev21, ev22]|
+---+------------+------------+

or:

df_target.show()
+---+------------+------------+-------+
| id|          ev|         ev2|evInEv2|
+---+------------+------------+-------+
|se1|[ev11, ev12]|      [ev11]|  false|
|se2|      [ev11]|[ev11, ev12]|   true|
|se3|      [ev21]|[ev11, ev12]|  false|
|se4|[ev21, ev22]|[ev21, ev22]|   true|
+---+------------+------------+-------+

I tried using the isin method:

df.withColumn('evInEv2', df['ev'].isin(df['ev2'])).show()
+---+------------+------------+-------+
| id|          ev|         ev2|evInEv2|
+---+------------+------------+-------+
|se1|[ev11, ev12]|      [ev11]|  false|
|se2|      [ev11]|[ev11, ev12]|  false|
|se3|      [ev21]|[ev11, ev12]|  false|
|se4|[ev21, ev22]|[ev21, ev22]|   true|
+---+------------+------------+-------+

But it looks like it only checks whether the two arrays are identical: ev.isin(ev2) compiles down to ev IN (ev2), which is an equality test against the whole ev2 array rather than an element-wise containment check.

I also tried the array_contains function from pyspark.sql.functions, but it only accepts a single object to check for, not an array.
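
For reference, a minimal sketch of what array_contains can do (the hasEv11 column name is just illustrative): it tests a single literal value against an array column, which is why it doesn't fit this use case.

from pyspark.sql.functions import array_contains

# array_contains checks one value against an array column; it cannot take
# a whole array column as the value to search for.
df.withColumn("hasEv11", array_contains(df.ev2, "ev11")).show()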

I even had trouble searching for this, since I'm not sure of the right wording for the problem.

Thanks!

Here's an option using a udf, where we check the length of the difference between the columns ev and ev2. When the length of the resulting set is 0, i.e. all elements of ev are contained within ev2, we return True; otherwise False.

from pyspark.sql.functions import udf

def contains(x, y):
  # An empty set difference means every element of x also appears in y.
  return len(set(x) - set(y)) == 0

contains_udf = udf(contains)
df.withColumn("evInEv2", contains_udf(df.ev, df.ev2)).show()
+---+------------+------------+-------+
| id|          ev|         ev2|evInEv2|
+---+------------+------------+-------+
|se1|[ev11, ev12]|      [ev11]|  false|
|se2|      [ev11]|[ev11, ev12]|   true|
|se3|      [ev21]|[ev11, ev12]|  false|
|se4|[ev21, ev22]|[ev21, ev22]|   true|
+---+------------+------------+-------+

Alternatively, you could use:

subsetOf = udf(lambda A, B: set(A).issubset(set(B)))
df.withColumn("evInEv2", subsetOf(df.ev, df.ev2)).show()

Another implementation for Spark >= 2.4.0 that avoids the UDF and uses the built-in array_except:

from pyspark.sql.functions import size, array_except

def is_subset(a, b):
  # a is a subset of b iff removing b's elements from a leaves nothing behind.
  return size(array_except(a, b)) == 0

df.withColumn("is_subset", is_subset(df.ev, df.ev2)).show()

Output:

+---+------------+------------+---------+
| id|          ev|         ev2|is_subset|
+---+------------+------------+---------+
|se1|[ev11, ev12]|      [ev11]|    false|
|se2|      [ev11]|[ev11, ev12]|     true|
|se3|      [ev21]|[ev11, ev12]|    false|
|se4|[ev21, ev22]|[ev21, ev22]|     true|
+---+------------+------------+---------+
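
If the goal is to select only the true cases rather than add a column, the same helper works directly as a filter predicate (a small usage sketch):

# Reproduces df_target: only the rows where ev is a subset of ev2.
df.filter(is_subset(df.ev, df.ev2)).show()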

I created a Spark UDF:

from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

antecedent_inside_predictions = udf(lambda antecedent, prediction: all(elem in prediction for elem in antecedent), BooleanType())

Then used it in the join:

fp_predictions = filtered_rules.join(personal_item_recos, antecedent_inside_predictions("antecedent", "item_predictions"))

Note that I needed to enable cross joins:

spark.conf.set('spark.sql.crossJoin.enabled', True)

Finally, I extract the specific item I want from it as follows:

fp_predictions = fp_predictions.withColumn("ITEM_SK", fp_predictions.consequent.getItem(0))
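
Since filtered_rules and personal_item_recos are not defined in this answer, here is a self-contained sketch of the same join-on-UDF pattern built from the question's df; the rules/recos names and their columns are hypothetical stand-ins:

from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

# Non-equi joins on a UDF condition require cross joins to be enabled.
spark.conf.set("spark.sql.crossJoin.enabled", True)

antecedent_inside_predictions = udf(
    lambda antecedent, prediction: all(elem in prediction for elem in antecedent),
    BooleanType())

# Hypothetical stand-ins for filtered_rules / personal_item_recos:
rules = df.select(df.id.alias("rule_id"), df.ev.alias("antecedent"))
recos = df.select(df.id.alias("user_id"), df.ev2.alias("item_predictions"))

# Keep every (rule, reco) pair where the antecedent is contained in the predictions.
joined = rules.join(recos,
                    antecedent_inside_predictions(rules.antecedent, recos.item_predictions))
joined.show()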