Pyspark dataframe:计算列中的唯一值,独立地与其他列中的值同时出现

Pyspark dataframe: Counting of unique values in a column, independently co-ocurring with values in other columns

我有一个 spark 数据框,其中包含从各种来源获得的两种类型的分子、调节剂和靶标(它们之间没有重叠)之间相互作用的数十亿个预测。我需要添加一列 包含预测给定 'Regulator' 和给定 'Target'.

至少一次交互的数字资源

换句话说,对于每对 'Regulator' 和 'Target',我正在尝试获取包含 'Regulator' 和 'Target' 值的来源数量,即使不是在一次互动中配对。

示例:

+---------+------+------+
|Regulator|Target|Source|
+---------+------+------+
|        m|     A|     x|
|        m|     B|     x|
|        m|     C|     z|
|        n|     A|     y|
|        n|     C|     x|
|        n|     C|     z|
+---------+------+------+

我要获取的是:

+---------+------+------+----------+
|Regulator|Target|Source|No.sources|
+---------+------+------+----------+
|        m|     A|     x|         1|
|        m|     B|     x|         1|
|        m|     C|     z|         2|
|        n|     A|     y|         2|
|        n|     C|     x|         2|
|        n|     C|     z|         2|
+---------+------+------+----------+

进一步说明:

第一行(m, A, x):

第二行(m, B, x)

第三行(m, C, z)

这是解决此问题的一种方法。对于每一行,创建 2 个新列:

  • 'RS''Regulator'
  • 的来源集
  • 'TS''Target'
  • 的来源集

那么你想要的输出就是这些集合的交集的长度。

考虑以下示例:

创建数据帧

from pyspark.sql Window
import pyspark.sql.functions as f
cols = ["Regulator", "Target", "Source"]
data = [
    ('m', 'A', 'x'),
    ('m', 'B', 'x'),
    ('m', 'C', 'z'),
    ('n', 'A', 'y'),
    ('n', 'C', 'x'),
    ('n', 'C', 'z')
]

df = sqlCtx.createDataFrame(data, cols)

创建新列

使用 pyspark.sql.functions.collect_set() and pyspark.sql.Window 计算 'Source' 列的不同值:

df = df.withColumn(
    'RS',
    f.collect_set(f.col('Source')).over(Window.partitionBy('Regulator'))
)

df = df.withColumn(
    'TS',
    f.collect_set(f.col('Source')).over(Window.partitionBy('Target'))
)
df.sort('Regulator', 'Target', 'Source').show()
#+---------+------+------+------+---------+
#|Regulator|Target|Source|    TS|       RS|
#+---------+------+------+------+---------+
#|        m|     A|     x|[y, x]|   [z, x]|
#|        m|     B|     x|   [x]|   [z, x]|
#|        m|     C|     z|[z, x]|   [z, x]|
#|        n|     A|     y|[y, x]|[y, z, x]|
#|        n|     C|     x|[z, x]|[y, z, x]|
#|        n|     C|     z|[z, x]|[y, z, x]|
#+---------+------+------+------+---------+

计算交点的长度

定义一个udf到return两个集合的交集长度,并用它来计算'No_sources'列。 (请注意,我在列名中使用 _ 而不是 .,因为它更易于使用 select()。)

intersection_length_udf = f.udf(lambda u, v: len(set(u) & set(v)), IntegerType())

df = df.withColumn('No_sources', intersection_length_udf(f.col('TS'), f.col('RS')))

df.select('Regulator', 'Target', 'Source', 'No_sources')\
    .sort('Regulator', 'Target', 'Source')\
    .show()
#+---------+------+------+----------+
#|Regulator|Target|Source|No_sources|
#+---------+------+------+----------+
#|        m|     A|     x|         1|
#|        m|     B|     x|         1|
#|        m|     C|     z|         2|
#|        n|     A|     y|         2|
#|        n|     C|     x|         2|
#|        n|     C|     z|         2|
#+---------+------+------+----------+