提交 63875b3c 编写于 作者: G godfreyhe

[FLINK-20738][table-planner-blink] Introduce BatchPhysicalLocalSortAggregate,...

[FLINK-20738][table-planner-blink] Introduce BatchPhysicalLocalSortAggregate, and make BatchPhysicalLocalSortAggregate only extended from FlinkPhysicalRel

This closes #14562
上级 29c81fe1
...@@ -623,7 +623,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { ...@@ -623,7 +623,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
} else { } else {
null null
} }
case agg: BatchExecLocalSortAggregate => case agg: BatchPhysicalLocalSortAggregate =>
getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType) getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
case agg: BatchPhysicalSortAggregate if agg.isMerge => case agg: BatchPhysicalSortAggregate if agg.isMerge =>
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg( val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
......
...@@ -18,20 +18,12 @@ ...@@ -18,20 +18,12 @@
package org.apache.flink.table.planner.plan.nodes.physical.batch package org.apache.flink.table.planner.plan.nodes.physical.batch
import org.apache.flink.api.dag.Transformation
import org.apache.flink.table.data.RowData
import org.apache.flink.table.functions.UserDefinedFunction import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGeneratorContext
import org.apache.flink.table.planner.codegen.agg.batch.{AggWithoutKeysCodeGenerator, SortAggCodeGenerator}
import org.apache.flink.table.planner.delegation.BatchPlanner
import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef}
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecSortAggregate
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, LegacyBatchExecNode} import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToBatchAggregateInfoList
import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, RelExplainUtil} import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, RelExplainUtil}
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet}
import org.apache.calcite.rel.RelDistribution.Type import org.apache.calcite.rel.RelDistribution.Type
...@@ -49,7 +41,7 @@ import scala.collection.JavaConversions._ ...@@ -49,7 +41,7 @@ import scala.collection.JavaConversions._
* *
* @see [[BatchPhysicalGroupAggregateBase]] for more info. * @see [[BatchPhysicalGroupAggregateBase]] for more info.
*/ */
class BatchExecLocalSortAggregate( class BatchPhysicalLocalSortAggregate(
cluster: RelOptCluster, cluster: RelOptCluster,
traitSet: RelTraitSet, traitSet: RelTraitSet,
inputRel: RelNode, inputRel: RelNode,
...@@ -67,11 +59,10 @@ class BatchExecLocalSortAggregate( ...@@ -67,11 +59,10 @@ class BatchExecLocalSortAggregate(
auxGrouping, auxGrouping,
aggCallToAggFunction, aggCallToAggFunction,
isMerge = false, isMerge = false,
isFinal = false) isFinal = false) {
with LegacyBatchExecNode[RowData] {
override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new BatchExecLocalSortAggregate( new BatchPhysicalLocalSortAggregate(
cluster, cluster,
traitSet, traitSet,
inputs.get(0), inputs.get(0),
...@@ -136,52 +127,25 @@ class BatchExecLocalSortAggregate( ...@@ -136,52 +127,25 @@ class BatchExecLocalSortAggregate(
Some(copy(newProvidedTraits, Seq(newInput))) Some(copy(newProvidedTraits, Seq(newInput)))
} }
//~ ExecNode methods ----------------------------------------------------------- override def translateToExecNode(): ExecNode[_] = {
new BatchExecSortAggregate(
override protected def translateToPlanInternal( grouping,
planner: BatchPlanner): Transformation[RowData] = { auxGrouping,
val input = getInputNodes.get(0).translateToPlan(planner) getAggCallList.toArray,
.asInstanceOf[Transformation[RowData]] FlinkTypeFactory.toLogicalRowType(inputRowType),
val ctx = CodeGeneratorContext(planner.getTableConfig) false, // isMerge is always false
val outputType = FlinkTypeFactory.toLogicalRowType(getRowType) false, // isFinal is always false
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) getInputEdge,
FlinkTypeFactory.toLogicalRowType(getRowType),
val aggInfos = transformToBatchAggregateInfoList( getRelDetailedDescription
FlinkTypeFactory.toLogicalRowType(inputRowType), getAggCallList) )
val generatedOperator = if (grouping.isEmpty) {
AggWithoutKeysCodeGenerator.genWithoutKeys(
ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping")
} else {
SortAggCodeGenerator.genWithKeys(
ctx,
planner.getRelBuilder,
aggInfos,
inputType,
outputType,
grouping,
auxGrouping,
isMerge,
isFinal)
}
val operator = new CodeGenOperatorFactory[RowData](generatedOperator)
ExecNodeUtil.createOneInputTransformation(
input,
getRelDetailedDescription,
operator,
InternalTypeInfo.of(outputType),
input.getParallelism,
0)
} }
override def getInputEdges: util.List[ExecEdge] = { private def getInputEdge: ExecEdge = {
if (grouping.length == 0) { if (grouping.length == 0) {
List( ExecEdge.builder().damBehavior(ExecEdge.DamBehavior.END_INPUT).build()
ExecEdge.builder()
.damBehavior(ExecEdge.DamBehavior.END_INPUT)
.build())
} else { } else {
List(ExecEdge.DEFAULT) ExecEdge.DEFAULT
} }
} }
} }
...@@ -24,7 +24,7 @@ import org.apache.flink.table.planner.JArrayList ...@@ -24,7 +24,7 @@ import org.apache.flink.table.planner.JArrayList
import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._ import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalSortAggregate, BatchPhysicalGroupAggregateBase, BatchPhysicalLocalHashAggregate} import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalGroupAggregateBase, BatchPhysicalLocalHashAggregate, BatchPhysicalLocalSortAggregate}
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil} import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil}
import org.apache.flink.table.planner.utils.AggregatePhaseStrategy import org.apache.flink.table.planner.utils.AggregatePhaseStrategy
import org.apache.flink.table.planner.utils.TableConfigUtils.getAggPhaseStrategy import org.apache.flink.table.planner.utils.TableConfigUtils.getAggPhaseStrategy
...@@ -219,7 +219,7 @@ trait BatchPhysicalAggRuleBase { ...@@ -219,7 +219,7 @@ trait BatchPhysicalAggRuleBase {
auxGrouping, auxGrouping,
aggCallToAggFunction) aggCallToAggFunction)
} else { } else {
new BatchExecLocalSortAggregate( new BatchPhysicalLocalSortAggregate(
cluster, cluster,
traitSet, traitSet,
input, input,
......
...@@ -39,7 +39,7 @@ import scala.collection.JavaConversions._ ...@@ -39,7 +39,7 @@ import scala.collection.JavaConversions._
* BatchPhysicalSortAggregate (global) * BatchPhysicalSortAggregate (global)
* +- Sort (exists if group keys are not empty) * +- Sort (exists if group keys are not empty)
* +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
* +- BatchExecLocalSortAggregate (local) * +- BatchPhysicalLocalSortAggregate (local)
* +- Sort (exists if group keys are not empty) * +- Sort (exists if group keys are not empty)
* +- input of agg * +- input of agg
* }}} * }}}
...@@ -88,7 +88,7 @@ class BatchPhysicalSortAggRule ...@@ -88,7 +88,7 @@ class BatchPhysicalSortAggRule
// create two-phase agg if possible // create two-phase agg if possible
if (isTwoPhaseAggWorkable(aggFunctions, tableConfig)) { if (isTwoPhaseAggWorkable(aggFunctions, tableConfig)) {
// create BatchExecLocalSortAggregate // create BatchPhysicalLocalSortAggregate
var localRequiredTraitSet = input.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL) var localRequiredTraitSet = input.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)
if (agg.getGroupCount != 0) { if (agg.getGroupCount != 0) {
val sortCollation = createRelCollation(groupSet) val sortCollation = createRelCollation(groupSet)
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
package org.apache.flink.table.planner.plan.rules.physical.batch package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalSortAggregate, BatchPhysicalSort, BatchPhysicalSortAggregate} import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalLocalSortAggregate, BatchPhysicalSort, BatchPhysicalSortAggregate}
import org.apache.calcite.plan.RelOptRule._ import org.apache.calcite.plan.RelOptRule._
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptRuleOperand} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptRuleOperand}
...@@ -56,7 +56,7 @@ abstract class RemoveRedundantLocalSortAggRule( ...@@ -56,7 +56,7 @@ abstract class RemoveRedundantLocalSortAggRule(
private[table] def getOriginalGlobalAgg(call: RelOptRuleCall): BatchPhysicalSortAggregate private[table] def getOriginalGlobalAgg(call: RelOptRuleCall): BatchPhysicalSortAggregate
private[table] def getOriginalLocalAgg(call: RelOptRuleCall): BatchExecLocalSortAggregate private[table] def getOriginalLocalAgg(call: RelOptRuleCall): BatchPhysicalLocalSortAggregate
private[table] def getOriginalInputOfLocalAgg(call: RelOptRuleCall): RelNode private[table] def getOriginalInputOfLocalAgg(call: RelOptRuleCall): RelNode
...@@ -64,7 +64,7 @@ abstract class RemoveRedundantLocalSortAggRule( ...@@ -64,7 +64,7 @@ abstract class RemoveRedundantLocalSortAggRule(
class RemoveRedundantLocalSortAggWithoutSortRule extends RemoveRedundantLocalSortAggRule( class RemoveRedundantLocalSortAggWithoutSortRule extends RemoveRedundantLocalSortAggRule(
operand(classOf[BatchPhysicalSortAggregate], operand(classOf[BatchPhysicalSortAggregate],
operand(classOf[BatchExecLocalSortAggregate], operand(classOf[BatchPhysicalLocalSortAggregate],
operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any))), operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any))),
"RemoveRedundantLocalSortAggWithoutSortRule") { "RemoveRedundantLocalSortAggWithoutSortRule") {
...@@ -74,8 +74,8 @@ class RemoveRedundantLocalSortAggWithoutSortRule extends RemoveRedundantLocalSor ...@@ -74,8 +74,8 @@ class RemoveRedundantLocalSortAggWithoutSortRule extends RemoveRedundantLocalSor
} }
override private[table] def getOriginalLocalAgg( override private[table] def getOriginalLocalAgg(
call: RelOptRuleCall): BatchExecLocalSortAggregate = { call: RelOptRuleCall): BatchPhysicalLocalSortAggregate = {
call.rels(1).asInstanceOf[BatchExecLocalSortAggregate] call.rels(1).asInstanceOf[BatchPhysicalLocalSortAggregate]
} }
override private[table] def getOriginalInputOfLocalAgg(call: RelOptRuleCall): RelNode = { override private[table] def getOriginalInputOfLocalAgg(call: RelOptRuleCall): RelNode = {
...@@ -87,7 +87,7 @@ class RemoveRedundantLocalSortAggWithoutSortRule extends RemoveRedundantLocalSor ...@@ -87,7 +87,7 @@ class RemoveRedundantLocalSortAggWithoutSortRule extends RemoveRedundantLocalSor
class RemoveRedundantLocalSortAggWithSortRule extends RemoveRedundantLocalSortAggRule( class RemoveRedundantLocalSortAggWithSortRule extends RemoveRedundantLocalSortAggRule(
operand(classOf[BatchPhysicalSortAggregate], operand(classOf[BatchPhysicalSortAggregate],
operand(classOf[BatchPhysicalSort], operand(classOf[BatchPhysicalSort],
operand(classOf[BatchExecLocalSortAggregate], operand(classOf[BatchPhysicalLocalSortAggregate],
operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any)))), operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any)))),
"RemoveRedundantLocalSortAggWithSortRule") { "RemoveRedundantLocalSortAggWithSortRule") {
...@@ -97,8 +97,8 @@ class RemoveRedundantLocalSortAggWithSortRule extends RemoveRedundantLocalSortAg ...@@ -97,8 +97,8 @@ class RemoveRedundantLocalSortAggWithSortRule extends RemoveRedundantLocalSortAg
} }
override private[table] def getOriginalLocalAgg( override private[table] def getOriginalLocalAgg(
call: RelOptRuleCall): BatchExecLocalSortAggregate = { call: RelOptRuleCall): BatchPhysicalLocalSortAggregate = {
call.rels(2).asInstanceOf[BatchExecLocalSortAggregate] call.rels(2).asInstanceOf[BatchPhysicalLocalSortAggregate]
} }
override private[table] def getOriginalInputOfLocalAgg(call: RelOptRuleCall): RelNode = { override private[table] def getOriginalInputOfLocalAgg(call: RelOptRuleCall): RelNode = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册