如何在 pyspark 中使用 groupby 创建转换矩阵

How to create transition matrix with groupby in pyspark

我有一个 pyspark 数据框,看起来像这样

import pandas as pd
so = pd.DataFrame({'id': ['a','a','a','a','b','b','b','b','c','c','c','c'],
                   'time': [1,2,3,4,1,2,3,4,1,2,3,4],
                   'group':['A','A','A','A','A','A','A','A','B','B','B','B'],
                   'value':['S','C','C','C', 'S','C','H', 'H', 'S','C','C','C']})

df_so = spark.createDataFrame(so)
df_so.show()

+---+----+-----+-----+
| id|time|group|value|
+---+----+-----+-----+
|  a|   1|    A|    S|
|  a|   2|    A|    C|
|  a|   3|    A|    C|
|  a|   4|    A|    C|
|  b|   1|    A|    S|
|  b|   2|    A|    C|
|  b|   3|    A|    H|
|  b|   4|    A|    H|
|  c|   1|    B|    S|
|  c|   2|    B|    C|
|  c|   3|    B|    C|
|  c|   4|    B|    C|
+---+----+-----+-----+

我想通过 group

创建 value 的“转换矩阵”

转移矩阵表示例如从 value Svalue C 每个 idtime 进步。

示例:

对于group A

我们可以分别对group B

做同样的事情

有没有办法在 pyspark 中做到这一点?

首先我使用lag为每一行制作转换的源列(转换的左侧),然后通过source & value(目标) 除以总数。

lagw = Window.partitionBy(['group', 'id']).orderBy('time')
frqw = Window.partitionBy(['group', 'source', 'value'])
ttlw = Window.partitionBy('group')

df = (df.withColumn('source', F.lag('value').over(lagw))
  .withColumn('transition_p', F.count('source').over(frqw) / F.count('source').over(ttlw)))

df.show()

# +---+----+-----+-----+------+------------+
# | id|time|group|value|source|transition_p|
# +---+----+-----+-----+------+------------+
# |  c|   1|    B|    S|  null|         0.0|
# |  c|   3|    B|    C|     C| 0.666666666|
# |  c|   4|    B|    C|     C| 0.666666666|
# |  c|   2|    B|    C|     S| 0.333333333|
# |  b|   1|    A|    S|  null|         0.0|
# .....

如果最后我明白你喜欢什么,

(df.filter(df.group == 'A')
 .groupby('source')
 .pivot('value')
 .agg(F.first('transition_p'))
).show()

# +------+---------+---------+---------+
# |source|        C|        H|        S| 
# +------+---------+---------+---------+
# |  null|     null|     null|      0.0|
# |     C|0.3333333|0.1666666|     null|
# |     S|0.3333333|     null|     null|
# |     H|     null|0.1666666|     null|
# +------+---------+---------+---------+