From 7d05d02bffe5f1c4fbf955664bcc87e38ce01f5f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 8 Mar 2016 23:34:42 +0800 Subject: [PATCH] [SPARK-13637][SQL] use more information to simplify the code in Expand builder ## What changes were proposed in this pull request? The code in `Expand.apply` can be simplified using information that is already available: * the `groupByExprs` parameter contains only `Attribute`s * the `child` parameter is a `Project` that appends aliased group by expressions to its child's output ## How was this patch tested? By existing tests. Author: Wenchen Fan Closes #11485 from cloud-fan/expand. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../plans/logical/basicOperators.scala | 48 +++++++++---------- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b5fa372643..268d7f21e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -298,12 +298,10 @@ class Analyzer( }.asInstanceOf[NamedExpression] } - val child = Project(x.child.output ++ groupByAliases, x.child) - Aggregate( groupByAttributes :+ VirtualColumn.groupingIdAttribute, aggregations, - Expand(x.bitmasks, groupByAttributes, gid, child)) + Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 411594c951..3bc246a32d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -449,21 +449,21 @@ private[sql] object Expand { * 
Extract attribute set according to the grouping id. * * @param bitmask bitmask to represent the selected of the attribute sequence - * @param exprs the attributes in sequence + * @param attrs the attributes in sequence * @return the attributes of non selected specified via bitmask (with the bit set to 1) */ - private def buildNonSelectExprSet( + private def buildNonSelectAttrSet( bitmask: Int, - exprs: Seq[Expression]): ArrayBuffer[Expression] = { - val set = new ArrayBuffer[Expression](2) + attrs: Seq[Attribute]): AttributeSet = { + val nonSelect = new ArrayBuffer[Attribute]() - var bit = exprs.length - 1 + var bit = attrs.length - 1 while (bit >= 0) { - if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1) + if (((bitmask >> bit) & 1) == 1) nonSelect += attrs(attrs.length - bit - 1) bit -= 1 } - set + AttributeSet(nonSelect) } /** @@ -471,13 +471,15 @@ private[sql] object Expand { * multiple output rows for a input row. * * @param bitmasks The bitmask set represents the grouping sets - * @param groupByExprs The grouping by expressions + * @param groupByAliases The aliased original group by expressions + * @param groupByAttrs The attributes of aliased group by expressions * @param gid Attribute of the grouping id * @param child Child operator */ def apply( bitmasks: Seq[Int], - groupByExprs: Seq[Expression], + groupByAliases: Seq[Alias], + groupByAttrs: Seq[Attribute], gid: Attribute, child: LogicalPlan): Expand = { // Create an array of Projections for the child projection, and replace the projections' @@ -485,27 +487,21 @@ private[sql] object Expand { // are not set for this grouping set (according to the bit mask). 
val projections = bitmasks.map { bitmask => // get the non selected grouping attributes according to the bit mask - val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs) + val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs) - (child.output :+ gid).map(expr => expr transformDown { - // TODO this causes a problem when a column is used both for grouping and aggregation. - case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) => + child.output ++ groupByAttrs.map { attr => + if (nonSelectedGroupAttrSet.contains(attr)) { // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null - Literal.create(null, expr.dataType) - case x if x == gid => - // replace the groupingId with concrete value (the bit mask) - Literal.create(bitmask, IntegerType) - }) - } - val output = child.output.map { attr => - if (groupByExprs.exists(_.semanticEquals(attr))) { - attr.withNullability(true) - } else { - attr - } + Literal.create(null, attr.dataType) + } else { + attr + } + // groupingId is the last output, here we use the bit mask as the concrete value for it. + } :+ Literal.create(bitmask, IntegerType) } - Expand(projections, output :+ gid, child) + val output = child.output ++ groupByAttrs :+ gid + Expand(projections, output, Project(child.output ++ groupByAliases, child)) } } -- GitLab