PySpark 余弦相似变换器

PySpark cosin-similarity Transformer

我有一个包含两列的 DataFrame,每列包含向量,例如

+-------------+------------+
|     v1      |     v2     |
+-------------+------------+
| [1,1.2,0.4] | [2,0.4,5]  |
| [1,.2,0.6]  | [2,.2,5]   |
| .           | .          |
| .           | .          |
| .           | .          |
| [0,1.2,.6]  | [2,.2,0.4] |
+-------------+------------+

我想在此 DataFrame 中添加另一列,其中包含每行中两个向量之间的余弦相似度。

不知道这里有任何可以直接计算余弦相似度的变换。 您可以为此类功能编写自己的 udf

from pyspark.ml.linalg import Vectors, DenseVector
from pyspark.sql import functions as F
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.types import *

v = [(DenseVector([1,1.2,0.4]), DenseVector([2,0.4,5])),
    (DenseVector([1,2,0.6]), DenseVector([2,0.2,5])),
    (DenseVector([0,1.2,0.6]), DenseVector([2,0.2,0.4]))]

dfv1 = spark.createDataFrame(v, ['v1', 'v2'])
dfv1 = dfv1.withColumn('v1v2', F.struct([F.col('v1'), F.col('v2')]))
dfv1.show(truncate=False)

这是具有组合向量的 DataFrame:

+-------------+-------------+------------------------------+
|v1           |v2           |v1v2                          |
+-------------+-------------+------------------------------+
|[1.0,1.2,0.4]|[2.0,0.4,5.0]|[[1.0,1.2,0.4], [2.0,0.4,5.0]]|
|[1.0,2.0,0.6]|[2.0,0.2,5.0]|[[1.0,2.0,0.6], [2.0,0.2,5.0]]|
|[0.0,1.2,0.6]|[2.0,0.2,0.4]|[[0.0,1.2,0.6], [2.0,0.2,0.4]]|
+-------------+-------------+------------------------------+

现在我们可以为余弦相似度定义udf

dot_prod_udf = F.udf(lambda v: float(v[0].dot(v[1])/v[0].norm(None)/v[1].norm(None)), FloatType())
dfv1 = dfv1.withColumn('cosine_similarity', dot_prod_udf(dfv1['v1v2']))
dfv1.show(truncate=False)

最后一列显示余弦相似度:

+-------------+-------------+------------------------------+-----------------+
|v1           |v2           |v1v2                          |cosine_similarity|
+-------------+-------------+------------------------------+-----------------+
|[1.0,1.2,0.4]|[2.0,0.4,5.0]|[[1.0,1.2,0.4], [2.0,0.4,5.0]]|0.51451445       |
|[1.0,2.0,0.6]|[2.0,0.2,5.0]|[[1.0,2.0,0.6], [2.0,0.2,5.0]]|0.4328257        |
|[0.0,1.2,0.6]|[2.0,0.2,0.4]|[[0.0,1.2,0.6], [2.0,0.2,0.4]]|0.17457432       |
+-------------+-------------+------------------------------+-----------------+