提交 7d05d02b 编写于 作者: W Wenchen Fan

[SPARK-13637][SQL] use more information to simplify the code in Expand builder

## What changes were proposed in this pull request?

The code in `Expand.apply` can be simplified by existing information:

* the `groupByExprs` parameter are all `Attribute`s
* the `child` parameter is a `Project` that append aliased group by expressions to its child's output

## How was this patch tested?

by existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11485 from cloud-fan/expand.
上级 9e86e6ef
......@@ -298,12 +298,10 @@ class Analyzer(
}.asInstanceOf[NamedExpression]
}
val child = Project(x.child.output ++ groupByAliases, x.child)
Aggregate(
groupByAttributes :+ VirtualColumn.groupingIdAttribute,
aggregations,
Expand(x.bitmasks, groupByAttributes, gid, child))
Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
}
}
......
......@@ -449,21 +449,21 @@ private[sql] object Expand {
* Extract attribute set according to the grouping id.
*
* @param bitmask bitmask to represent the selected of the attribute sequence
* @param exprs the attributes in sequence
* @param attrs the attributes in sequence
* @return the attributes of non selected specified via bitmask (with the bit set to 1)
*/
private def buildNonSelectExprSet(
private def buildNonSelectAttrSet(
bitmask: Int,
exprs: Seq[Expression]): ArrayBuffer[Expression] = {
val set = new ArrayBuffer[Expression](2)
attrs: Seq[Attribute]): AttributeSet = {
val nonSelect = new ArrayBuffer[Attribute]()
var bit = exprs.length - 1
var bit = attrs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1)
if (((bitmask >> bit) & 1) == 1) nonSelect += attrs(attrs.length - bit - 1)
bit -= 1
}
set
AttributeSet(nonSelect)
}
/**
......@@ -471,13 +471,15 @@ private[sql] object Expand {
* multiple output rows for a input row.
*
* @param bitmasks The bitmask set represents the grouping sets
* @param groupByExprs The grouping by expressions
* @param groupByAliases The aliased original group by expressions
* @param groupByAttrs The attributes of aliased group by expressions
* @param gid Attribute of the grouping id
* @param child Child operator
*/
def apply(
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
groupByAliases: Seq[Alias],
groupByAttrs: Seq[Attribute],
gid: Attribute,
child: LogicalPlan): Expand = {
// Create an array of Projections for the child projection, and replace the projections'
......@@ -485,27 +487,21 @@ private[sql] object Expand {
// are not set for this grouping set (according to the bit mask).
val projections = bitmasks.map { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs)
(child.output :+ gid).map(expr => expr transformDown {
// TODO this causes a problem when a column is used both for grouping and aggregation.
case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) =>
child.output ++ groupByAttrs.map { attr =>
if (nonSelectedGroupAttrSet.contains(attr)) {
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, expr.dataType)
case x if x == gid =>
// replace the groupingId with concrete value (the bit mask)
Literal.create(bitmask, IntegerType)
})
}
val output = child.output.map { attr =>
if (groupByExprs.exists(_.semanticEquals(attr))) {
attr.withNullability(true)
} else {
attr
}
Literal.create(null, attr.dataType)
} else {
attr
}
// groupingId is the last output, here we use the bit mask as the concrete value for it.
} :+ Literal.create(bitmask, IntegerType)
}
Expand(projections, output :+ gid, child)
val output = child.output ++ groupByAttrs :+ gid
Expand(projections, output, Project(child.output ++ groupByAliases, child))
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册