如何在pyspark中增加字段?

how to increment field in pyspark?

我必须根据条件生成一个增量数字,例如,我有以下数据框:

+-- |-----------+------+---  
seq |cod         |trans |ant 
+--+-------------+------+--- 
01 |05           |00    |1   
02 |05           |01    |00  
03 |05           |02    |01  
04 |05           |05    |02  
05 |05           |00    |05  
06 |05           |01    |00  
07 |05           |02    |01  
08 |05           |05    |02  
09 |05           |07    |05  
10 |05           |00    |07  
11 |05           |01    |00  
12 |05           |02    |01  
13 |05           |05    |02  

我使用:

global cont
df1 = df.withColumn("id",when(col("trans ").cast("int") < col("ant").cast("int"),cont+1).otherwise(cont))

我得到以下输出:

+-- |-----------+------+--- +---
seq |cod          |trans |ant |id 
+--+-------------+------+--- +---
01 |05           |00    |1    |1  
02 |05           |01    |00   |0  
03 |05           |02    |01   |0  
04 |05           |05    |02   |0 
05 |05           |00    |05   |1  
06 |05           |01    |00   |0  
07 |05           |02    |01   |0  
08 |05           |05    |02   |0  
09 |05           |07    |05   |0  
10 |05           |00    |07   |1  
11 |05           |01    |00   |0  
12 |05           |02    |01   |0 
13 |05           |05    |02   |0  

但我希望是这样的:

+-- |-----------+------+--- +---
seq |cod          |trans |ant |id 
+--+-------------+------+--- +---
01 |05           |00    |1    |1  
02 |05           |01    |00   |1  
03 |05           |02    |01   |1  
04 |05           |05    |02   |1  
05 |05           |00    |05   |2  
06 |05           |01    |00   |2  
07 |05           |02    |01   |2  
08 |05           |05    |02   |2  
09 |05           |07    |05   |2  
10 |05           |00    |07   |3  
11 |05           |01    |00   |3  
12 |05           |02    |01   |3  
13 |05           |05    |02   |3  

有没有人有什么建议可以帮助我?

您需要一个累计总和,您可以在此处使用 window 函数:

from pyspark.sql import functions as F, Window as W
w = (W.partitionBy("cod").orderBy(F.col("seq").cast("int"))
     .rangeBetween(W.unboundedPreceding,W.currentRow))
df1 = df.withColumn("id",F.sum(
        F.when(F.col("trans").cast("int") < F.col("ant").cast("int"),1).otherwise(0)
                              ).over(w)
                   )

df1.show()
+---+---+-----+---+---+
|seq|cod|trans|ant| id|
+---+---+-----+---+---+
| 01| 05|   00|  1|  1|
| 02| 05|   01| 00|  1|
| 03| 05|   02| 01|  1|
| 04| 05|   05| 02|  1|
| 05| 05|   00| 05|  2|
| 06| 05|   01| 00|  2|
| 07| 05|   02| 01|  2|
| 08| 05|   05| 02|  2|
| 09| 05|   07| 05|  2|
| 10| 05|   00| 07|  3|
| 11| 05|   01| 00|  3|
| 12| 05|   02| 01|  3|
| 13| 05|   05| 02|  3|
+---+---+-----+---+---+