PySpark:使用窗口功能汇总数据框

我有一个数据框my_df,其中包含4列:

+----------------+---------------+--------+---------+
|         user_id|         domain|isp_flag|frequency|
+----------------+---------------+--------+---------+
|            josh|     wanadoo.fr|       1|       15|
|            josh|      random.it|       0|       12|
|        samantha|     wanadoo.fr|       1|       16|
|             bob|    eidsiva.net|       1|        5|
|             bob|      media.net|       0|        1|
|           dylan|    vodafone.it|       1|      448|
|           dylan|   somesite.net|       0|       20|
|           dylan|   yolosite.net|       0|       49|
|           dylan|      random.it|       0|        3|
|             don|    vodafone.it|       1|       39|
|             don|   popsugar.com|       0|       10|
|             don|      fabio.com|       1|       49|
+----------------+---------------+--------+---------+

这是我计划要做的-

  

查找所有user_id,其中domain的最大频率isp_flag=0的频率小于domain的最大频率isp_flag=1的25%

因此,在上面的示例中,我的output_df看起来像-

+----------------+---------------+--------+---------+
|         user_id|         domain|isp_flag|frequency|
+----------------+---------------+--------+---------+
|             bob|    eidsiva.net|       1|        5|
|             bob|      media.net|       0|        1|
|           dylan|    vodafone.it|       1|      448|
|           dylan|   yolosite.net|       0|       49|
|             don|      fabio.com|       1|       49|
|             don|   popsugar.com|       0|       10|
+----------------+---------------+--------+---------+

我认为我需要使用窗口函数来执行此操作,因此我尝试了以下操作,首先针对每个isp_flag=0-分别找到isp_flag=1user_id的最大频域-

>>> win_1 = Window().partitionBy("user_id","domain","isp_flag").orderBy((col("frequency").desc()))
>>> final_df = my_df.select("*",rank().over(win_1).alias("rank")).filter(col("rank")==1)
>>> final_df.show(5)   # this just gives me the original dataframe back

我在这里做错了什么?我如何到达上面打印的最后一个output_df

hao521ye_88 回答:PySpark:使用窗口功能汇总数据框

IIUC,您可以尝试以下操作:为每个具有isp_flag == 0或1的用户计算max_frequencies(max_0,max_1)。然后按条件max_0 < 0.25*max_1和加frequency in (max_1,max_0)进行过滤,以仅选择频率最高的记录。

from pyspark.sql import Window,functions as F

# set up the Window to calculate max_0 and max_1 for each user
# having isp_flag = 0 and 1 respectively
w1 = Window.partitionBy('user_id').rowsBetween(Window.unboundedPreceding,Window.unboundedFollowing)

df.withColumn('max_1',F.max(F.expr("IF(isp_flag==1,frequency,NULL)")).over(w1))\ 
  .withColumn('max_0',F.max(F.expr("IF(isp_flag==0,NULL)")).over(w1))\ 
  .where('max_0 < 0.25*max_1 AND frequency in (max_1,max_0)') \ 
  .show() 
+-------+------------+--------+---------+-----+-----+                           
|user_id|      domain|isp_flag|frequency|max_1|max_0|
+-------+------------+--------+---------+-----+-----+
|    don|popsugar.com|       0|       10|   49|   10|
|    don|   fabio.com|       1|       49|   49|   10|
|  dylan| vodafone.it|       1|      448|  448|   49|
|  dylan|yolosite.net|       0|       49|  448|   49|
|    bob| eidsiva.net|       1|        5|    5|    1|
|    bob|   media.net|       0|        1|    5|    1|
+-------+------------+--------+---------+-----+-----+

每个请求的一些解释:

  • WindowSpec w1设置为检查同一用户(partitionBy)的所有记录,以便F.max()函数将基于同一用户。

  • 我们使用IF(isp_flag==1,NULL)查找具有isp_flag == 1的行的频率,当isp_flag不是1时它返回NULL,因此在F.max()函数中被跳过。这是一个SQL表达式,因此我们需要F.expr()函数来运行它。

  • F.max(...).over(w1)将采用执行上述SQL表达式得出的结果的最大值。此计算基于窗口w1

本文链接:https://www.f2er.com/3142712.html

大家都在问