Asked by: 小点点

Group DataFrame by column "grp" and compress it (for each column, take the last non-null value when rows are ordered by column "ord")


Suppose I have the following DataFrame:

+---+--------+---+----+----+
|grp|null_col|ord|col1|col2|
+---+--------+---+----+----+
|  1|    null|  3|null|  11|
|  2|    null|  2| xxx|  22|
|  1|    null|  1| yyy|null|
|  2|    null|  7|null|  33|
|  1|    null| 12|null|null|
|  2|    null| 19|null|  77|
|  1|    null| 10| s13|null|
|  2|    null| 11| a23|null|
+---+--------+---+----+----+

Here is the same sample DF with comments, sorted by grp and ord:

scala> df.orderBy("grp", "ord").show
+---+--------+---+----+----+
|grp|null_col|ord|col1|col2|
+---+--------+---+----+----+
|  1|    null|  1| yyy|null|
|  1|    null|  3|null|  11|   # grp:1 - last value for `col2` (11)
|  1|    null| 10| s13|null|   # grp:1 - last value for `col1` (s13)
|  1|    null| 12|null|null|   # grp:1 - last values for `null_col`, `ord`
|  2|    null|  2| xxx|  22|   
|  2|    null|  7|null|  33|   
|  2|    null| 11| a23|null|   # grp:2 - last value for `col1` (a23)
|  2|    null| 19|null|  77|   # grp:2 - last values for `null_col`, `ord`, `col2`
+---+--------+---+----+----+

I want to compress it: group it by column "grp", and within each group, sort the rows by column "ord" and take the last non-null value in each column (if there is one).

+---+--------+---+----+----+
|grp|null_col|ord|col1|col2|
+---+--------+---+----+----+
|  1|    null| 12| s13|  11|
|  2|    null| 19| a23|  77|
+---+--------+---+----+----+

I have seen the following similar questions:

  • How to select the first row of each group?
  • How to find first non-null values in groups? (secondary sorting using dataset api)

But my real DataFrame has more than 250 columns, so I need a solution where I don't have to specify all the columns explicitly.

I can't wrap my head around it...

MCVE: how to create the sample DataFrame:


  • Create a local file /tmp/data.txt and copy & paste the contents of the DataFrame (as posted above)
  • Define the function readSparkOutput() (a possible sketch is given after the snippet below)
  • Parse /tmp/data.txt into a DataFrame:

    val df = readSparkOutput("file:///tmp/data.txt")
    
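    The readSparkOutput() helper comes from the linked answer and is not reproduced here; a rough, illustrative sketch of what such a parser could look like (the exact implementation in the linked answer may differ) is:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions._

    // illustrative sketch: parse the ASCII table printed by df.show() back into a DataFrame
    def readSparkOutput(filePath: String): DataFrame = {
      val raw = spark.read
        .option("header", "true")
        .option("inferSchema", "true")
        .option("delimiter", "|")
        .option("comment", "+")                      // skip the +---+ border lines
        .option("ignoreLeadingWhiteSpace", "true")
        .option("ignoreTrailingWhiteSpace", "true")
        .csv(filePath)

      // drop the empty columns produced by the leading/trailing "|"
      val trimmed = raw.select(raw.columns.filterNot(_.startsWith("_c")).map(raw(_)): _*)

      // turn the literal string "null" back into real nulls
      trimmed.select(trimmed.columns.map(c =>
        when(col(c) === "null", lit(null)).otherwise(col(c)).as(c)): _*)
    }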

    UPDATE: I think it should be something like the following SQL:

    SELECT
      grp, ord, null_col, col1, col2
    FROM (
        SELECT
          grp,
          ord,
          FIRST(null_col) OVER (PARTITION BY grp ORDER BY ord DESC) as null_col,
          FIRST(col1) OVER (PARTITION BY grp ORDER BY ord DESC) as col1,
          FIRST(col2) OVER (PARTITION BY grp ORDER BY ord DESC) as col2,
          ROW_NUMBER() OVER (PARTITION BY grp ORDER BY ord DESC) as rn
        FROM table_name) as v
    WHERE v.rn = 1;
    

    How can we generate such a Spark query dynamically?

    I tried the following simplified approach:

    import org.apache.spark.sql.expressions.Window
    
    val win = Window
      .partitionBy("grp")
      .orderBy($"ord".desc)
    
    val cols = df.columns.map(c => first(c, ignoreNulls=true).over(win).as(c))
    

    It produces:

    scala> cols
    res23: Array[org.apache.spark.sql.Column] = Array(first(grp, true) OVER (PARTITION BY grp ORDER BY ord DESC NULLS LAST UnspecifiedFrame) AS `grp`, first(null_col, true) OVER (PARTITION BY grp ORDER BY ord DESC NULLS LAST UnspecifiedFrame) AS `null_col`, first(ord, true) OVER (PARTITION BY grp ORDER BY ord DESC NULLS LAST UnspecifiedFrame) AS `ord`, first(col1, true) OVER (PARTITION BY grp ORDER BY ord DESC NULLS LAST UnspecifiedFrame) AS `col1`, first(col2, true) OVER (PARTITION BY grp ORDER BY ord DESC NULLS LAST UnspecifiedFrame) AS `col2`)
    

    But I can't pass it to df.select:

    scala> df.select(cols.head, cols.tail: _*).show
    <console>:34: error: no `: _*' annotation allowed here
    (such annotations are only allowed in arguments to *-parameters)
           df.select(cols.head, cols.tail: _*).show
    

    Another attempt:

    scala> df.select(cols.map(col): _*).show
    <console>:34: error: type mismatch;
     found   : String => org.apache.spark.sql.Column
     required: org.apache.spark.sql.Column => ?
           df.select(cols.map(col): _*).show
    
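    For reference, here is one way the pieces above might be wired together (an illustrative sketch, not taken from the answers below): give the window an explicit full frame so that last(..., ignoreNulls = true) can see the whole group, build the projection for all columns dynamically, pass the whole array to select with a single : _*, and finish with distinct(), since every row of a group then carries the same compacted values:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // full frame: every row sees its entire "grp" partition, ordered by "ord"
    val win = Window
      .partitionBy("grp")
      .orderBy("ord")
      .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

    // last non-null value per column, generated for all (250+) columns at once
    val compacted = df.columns.map(c => last(col(c), ignoreNulls = true).over(win).as(c))

    df.select(compacted: _*).distinct().orderBy("grp").show
    // expected (matches the desired output above):
    // +---+--------+---+----+----+
    // |grp|null_col|ord|col1|col2|
    // +---+--------+---+----+----+
    // |  1|    null| 12| s13|  11|
    // |  2|    null| 19| a23|  77|
    // +---+--------+---+----+----+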

  • 3 Answers

    Anonymous user

    Consider the following approach, which applies the Window function last(c, ignoreNulls = true), ordered by "ord" within each "grp" partition, to each of the columns to be aggregated, followed by a groupBy("grp") to fetch the first agg(colFcnMap) result:

    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.expressions.Window
    
    val df0 = Seq(
      (1, 3, None, Some(11)),
      (2, 2, Some("aaa"), Some(22)),
      (1, 1, Some("s12"), None),
      (2, 7, None, Some(33)),
      (1, 12, None, None),
      (2, 19, None, Some(77)),
      (1, 10, Some("s13"), None),
      (2, 11, Some("a23"), None)
    ).toDF("grp", "ord", "col1", "col2")
    
    val df = df0.withColumn("null_col", lit(null))
    
    df.orderBy("grp", "ord").show
    // +---+---+----+----+--------+
    // |grp|ord|col1|col2|null_col|
    // +---+---+----+----+--------+
    // |  1|  1| s12|null|    null|
    // |  1|  3|null|  11|    null|
    // |  1| 10| s13|null|    null|
    // |  1| 12|null|null|    null|
    // |  2|  2| aaa|  22|    null|
    // |  2|  7|null|  33|    null|
    // |  2| 11| a23|null|    null|
    // |  2| 19|null|  77|    null|
    // +---+---+----+----+--------+
    
    val win = Window.partitionBy("grp").orderBy("ord").
      rowsBetween(0, Window.unboundedFollowing)
    
    val nonAggCols = Array("grp")
    val cols = df.columns.diff(nonAggCols)  // Columns to be aggregated
    
    val colFcnMap = cols.zip(Array.fill(cols.size)("first")).toMap
    // colFcnMap: scala.collection.immutable.Map[String,String] =
    //   Map(ord -> first, col1 -> first, col2 -> first, null_col -> first)
    
    cols.foldLeft(df)((acc, c) =>
        acc.withColumn(c, last(c, ignoreNulls=true).over(win))
      ).
      groupBy("grp").agg(colFcnMap).
      select(col("grp") :: colFcnMap.toList.map{case (c, f) => col(s"$f($c)").as(c)}: _*).
      show
    // +---+---+----+----+--------+
    // |grp|ord|col1|col2|null_col|
    // +---+---+----+----+--------+
    // |  1| 12| s13|  11|    null|
    // |  2| 19| a23|  77|    null|
    // +---+---+----+----+--------+
    

    Note that the final select is there to strip the function name (in this case first()) from the aggregated column names.
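    A variant of the same idea (an illustrative sketch, not part of the original answer): building explicit aggregate expressions instead of the name-to-function map makes the renaming select unnecessary:

    // assumes the df, win and cols defined above
    val aggExprs = cols.map(c => first(col(c), ignoreNulls = true).as(c))

    cols.foldLeft(df)((acc, c) => acc.withColumn(c, last(c, ignoreNulls = true).over(win))).
      groupBy("grp").
      agg(aggExprs.head, aggExprs.tail: _*).
      show
    // expected (same as above):
    // +---+---+----+----+--------+
    // |grp|ord|col1|col2|null_col|
    // +---+---+----+----+--------+
    // |  1| 12| s13|  11|    null|
    // |  2| 19| a23|  77|    null|
    // +---+---+----+----+--------+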

    Anonymous user

    I have worked something out; here is the code and the output:

    import org.apache.spark.sql.functions._
    import spark.implicits._
    
    val df0 = Seq(
      (1, 3, None, Some(11)),
      (2, 2, Some("aaa"), Some(22)),
      (1, 1, Some("s12"), None),
      (2, 7, None, Some(33)),
      (1, 12, None, None),
      (2, 19, None, Some(77)),
      (1, 10, Some("s13"), None),
      (2, 11, Some("a23"), None)
    ).toDF("grp", "ord", "col1", "col2")
    
    df0.show()
    
    //+---+---+----+----+
    //|grp|ord|col1|col2|
    //+---+---+----+----+
    //|  1|  3|null|  11|
    //|  2|  2| aaa|  22|
    //|  1|  1| s12|null|
    //|  2|  7|null|  33|
    //|  1| 12|null|null|
    //|  2| 19|null|  77|
    //|  1| 10| s13|null|
    //|  2| 11| a23|null|
    //+---+---+----+----+
    

    Sort the data by the first 2 columns:

    val df1 = df0.select("grp", "ord", "col1", "col2").orderBy("grp", "ord")
    
    df1.show()
    
    //+---+---+----+----+
    //|grp|ord|col1|col2|
    //+---+---+----+----+
    //|  1|  1| s12|null|
    //|  1|  3|null|  11|
    //|  1| 10| s13|null|
    //|  1| 12|null|null|
    //|  2|  2| aaa|  22|
    //|  2|  7|null|  33|
    //|  2| 11| a23|null|
    //|  2| 19|null|  77|
    //+---+---+----+----+
    
    val df2 = df1.groupBy("grp").agg(
      max("ord").alias("ord"),
      collect_set("col1").alias("col1"),
      collect_set("col2").alias("col2"))
    
    val df3 = df2.
      withColumn("new_col1", $"col1".apply(size($"col1").minus(1))).
      withColumn("new_col2", $"col2".apply(size($"col2").minus(1)))
    
    df3.show()
    
    //+---+---+----------+------------+--------+--------+
    //|grp|ord|      col1|        col2|new_col1|new_col2|
    //+---+---+----------+------------+--------+--------+
    //|  1| 12|[s12, s13]|        [11]|     s13|      11|
    //|  2| 19|[aaa, a23]|[33, 22, 77]|     a23|      77|
    //+---+---+----------+------------+--------+--------+
    

    You can remove the unwanted columns using drop("column_name").
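    For example (a small illustrative sketch of that suggestion), dropping the intermediate array columns and restoring the original column names:

    val result = df3.
      drop("col1", "col2").
      withColumnRenamed("new_col1", "col1").
      withColumnRenamed("new_col2", "col2")

    result.show()
    // expected, based on df3 above:
    // +---+---+----+----+
    // |grp|ord|col1|col2|
    // +---+---+----+----+
    // |  1| 12| s13|  11|
    // |  2| 19| a23|  77|
    // +---+---+----+----+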

    Anonymous user

    So, here we group by a and select the max of all other columns in the group:

    scala> val df = List((1,2,11), (1,1,1), (2,1,4), (2,3,5)).toDF("a", "b", "c")
    df: org.apache.spark.sql.DataFrame = [a: int, b: int ... 1 more field]
    
    scala> val aggCols = df.schema.map(_.name).filter(_ != "a").map(colName => sum(col(colName)).alias(s"max_$colName"))
    aggCols: Seq[org.apache.spark.sql.Column] = List(sum(b) AS `max_b`, sum(c) AS `max_c`)
    
    scala> df.groupBy(col("a")).agg(aggCols.head, aggCols.tail: _*)
    res0: org.apache.spark.sql.DataFrame = [a: int, max_b: bigint ... 1 more field]
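
    Note that the snippet aggregates with sum while aliasing the result columns max_*; a version that actually takes the max, as the description says, could look like this (an illustrative sketch):

    // same dynamic pattern, but aggregating with max so the values match the max_* aliases
    val maxCols = df.schema.map(_.name).filter(_ != "a").
      map(colName => max(col(colName)).alias(s"max_$colName"))

    df.groupBy(col("a")).agg(maxCols.head, maxCols.tail: _*).show()
    // expected: a=1 -> max_b = 2, max_c = 11; a=2 -> max_b = 3, max_c = 5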