提交 d0d8377b 编写于 作者: G godfreyhe

[FLINK-20738][table-planner-blink] Introduce BatchPhysicalHashAggregate, and...

[FLINK-20738][table-planner-blink] Introduce BatchPhysicalHashAggregate, and make BatchExecHashAggregate only extended from ExecNode

This closes #14562
上级 146c68df
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.plan.nodes.exec.batch;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.agg.batch.AggWithoutKeysCodeGenerator;
import org.apache.flink.table.planner.codegen.agg.batch.HashAggCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.runtime.generated.GeneratedOperator;
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.RowType;
import org.apache.calcite.rel.core.AggregateCall;
import java.util.Arrays;
import java.util.Collections;
/**
 * Batch {@link ExecNode} for hash-based aggregate operator.
 *
 * <p>The node translates its single input, generates the aggregate operator code (with or without
 * grouping keys) and wraps the generated operator into a one-input {@link Transformation}.
 */
public class BatchExecHashAggregate extends ExecNodeBase<RowData>
        implements BatchExecNode<RowData> {

    /** Grouping key indices; an empty array selects the code path without keys. */
    private final int[] grouping;

    /** Auxiliary grouping indices, handed over to {@link HashAggCodeGenerator}. */
    private final int[] auxGrouping;

    /** Aggregate calls from which the {@link AggregateInfoList} is built. */
    private final AggregateCall[] aggCalls;

    /** Row type the aggregate calls are resolved against when building the aggregate infos. */
    private final RowType aggInputRowType;

    /** Merge flag, forwarded unchanged to the code generators. */
    private final boolean isMerge;

    /** Final flag, forwarded unchanged to the code generators. */
    private final boolean isFinal;

    /**
     * Creates the node with a single input edge.
     *
     * @param grouping grouping key indices (empty for an aggregate without keys)
     * @param auxGrouping auxiliary grouping indices
     * @param aggCalls aggregate calls to evaluate
     * @param aggInputRowType row type used to derive the aggregate infos
     * @param isMerge merge flag passed to the code generators
     * @param isFinal final flag passed to the code generators
     * @param inputEdge the only input edge of this node
     * @param outputType output row type of this node
     * @param description description used for the resulting transformation
     */
    public BatchExecHashAggregate(
            int[] grouping,
            int[] auxGrouping,
            AggregateCall[] aggCalls,
            RowType aggInputRowType,
            boolean isMerge,
            boolean isFinal,
            ExecEdge inputEdge,
            RowType outputType,
            String description) {
        super(Collections.singletonList(inputEdge), outputType, description);
        this.grouping = grouping;
        this.auxGrouping = auxGrouping;
        this.aggCalls = aggCalls;
        this.aggInputRowType = aggInputRowType;
        this.isMerge = isMerge;
        this.isFinal = isFinal;
    }

    /**
     * Translates the input node first, then generates the aggregate operator and builds the
     * transformation, which reuses the parallelism of the input transformation.
     */
    @SuppressWarnings("unchecked")
    @Override
    protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
        final ExecNode<RowData> upstreamNode = (ExecNode<RowData>) getInputNodes().get(0);
        final Transformation<RowData> upstreamTransform = upstreamNode.translateToPlan(planner);

        final RowType upstreamRowType = (RowType) upstreamNode.getOutputType();
        final RowType resultRowType = (RowType) getOutputType();

        final TableConfig tableConfig = planner.getTableConfig();
        final CodeGeneratorContext codeGenCtx = new CodeGeneratorContext(tableConfig);

        final AggregateInfoList aggInfoList =
                AggregateUtil.transformToBatchAggregateInfoList(
                        aggInputRowType,
                        JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                        null,
                        null);

        // Managed memory is only requested when grouping keys exist; the memory size is
        // resolved before the operator code is generated.
        final boolean hasGroupingKeys = grouping.length != 0;
        final long managedMemoryBytes;
        final GeneratedOperator<OneInputStreamOperator<RowData, RowData>> aggOperator;
        if (hasGroupingKeys) {
            managedMemoryBytes =
                    ExecNodeUtil.getMemorySize(
                            tableConfig,
                            ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY);
            aggOperator =
                    generateWithKeys(
                            codeGenCtx, planner, aggInfoList, upstreamRowType, resultRowType);
        } else {
            managedMemoryBytes = 0L;
            aggOperator =
                    generateWithoutKeys(
                            codeGenCtx, planner, aggInfoList, upstreamRowType, resultRowType);
        }

        return ExecNodeUtil.createOneInputTransformation(
                upstreamTransform,
                getDesc(),
                new CodeGenOperatorFactory<>(aggOperator),
                InternalTypeInfo.of(resultRowType),
                upstreamTransform.getParallelism(),
                managedMemoryBytes);
    }

    /** Generates the operator for the case where no grouping keys are present. */
    private GeneratedOperator<OneInputStreamOperator<RowData, RowData>> generateWithoutKeys(
            CodeGeneratorContext codeGenCtx,
            PlannerBase planner,
            AggregateInfoList aggInfoList,
            RowType upstreamRowType,
            RowType resultRowType) {
        return AggWithoutKeysCodeGenerator.genWithoutKeys(
                codeGenCtx,
                planner.getRelBuilder(),
                aggInfoList,
                upstreamRowType,
                resultRowType,
                isMerge,
                isFinal,
                "NoGrouping");
    }

    /** Generates the hash aggregate operator for the case where grouping keys are present. */
    private GeneratedOperator<OneInputStreamOperator<RowData, RowData>> generateWithKeys(
            CodeGeneratorContext codeGenCtx,
            PlannerBase planner,
            AggregateInfoList aggInfoList,
            RowType upstreamRowType,
            RowType resultRowType) {
        final HashAggCodeGenerator keyedGenerator =
                new HashAggCodeGenerator(
                        codeGenCtx,
                        planner.getRelBuilder(),
                        aggInfoList,
                        upstreamRowType,
                        resultRowType,
                        grouping,
                        auxGrouping,
                        isMerge,
                        isFinal);
        return keyedGenerator.genWithKeys();
    }
}
......@@ -20,10 +20,13 @@ package org.apache.flink.table.planner.plan.nodes.exec.utils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
......@@ -35,6 +38,15 @@ import java.util.stream.Collectors;
/** A utility class that helps translate {@link ExecNode} to {@link Transformation}. */
public class ExecNodeUtil {
/**
 * Reads the string value of the given option from the {@link TableConfig}, parses it as a
 * {@link MemorySize} and returns the parsed size in bytes.
 *
 * <p>TODO: This method can be removed once FLINK-20879 is finished.
 *
 * @param tableConfig table config whose configuration holds the option value
 * @param option string-typed option containing a memory size expression
 * @return the configured memory size, expressed in bytes
 */
public static long getMemorySize(TableConfig tableConfig, ConfigOption<String> option) {
    final String configuredSize = tableConfig.getConfiguration().getString(option);
    final MemorySize parsedSize = MemorySize.parse(configuredSize);
    return parsedSize.getBytes();
}
/**
* Set memoryBytes to {@link
* Transformation#declareManagedMemoryUseCaseAtOperatorScope(ManagedMemoryUseCase, int)}.
......
......@@ -615,7 +615,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
agg.aggCalls(aggCallIndex)
case agg: BatchExecLocalHashAggregate =>
getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
case agg: BatchExecHashAggregate if agg.isMerge =>
case agg: BatchPhysicalHashAggregate if agg.isMerge =>
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
if (aggCallIndexInLocalAgg != null) {
......
......@@ -18,10 +18,22 @@
package org.apache.flink.table.planner.plan.nodes.physical.batch
import org.apache.flink.api.dag.Transformation
import org.apache.flink.configuration.MemorySize
import org.apache.flink.table.api.config.ExecutionConfigOptions
import org.apache.flink.table.data.RowData
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGeneratorContext
import org.apache.flink.table.planner.codegen.agg.batch.{AggWithoutKeysCodeGenerator, HashAggCodeGenerator}
import org.apache.flink.table.planner.delegation.BatchPlanner
import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef}
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, LegacyBatchExecNode}
import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToBatchAggregateInfoList
import org.apache.flink.table.planner.plan.utils.RelExplainUtil
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet}
import org.apache.calcite.rel.RelDistribution.Type
......@@ -48,18 +60,17 @@ class BatchExecLocalHashAggregate(
grouping: Array[Int],
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)])
extends BatchExecHashAggregateBase(
extends BatchPhysicalHashAggregateBase(
cluster,
traitSet,
inputRel,
outputRowType,
inputRowType,
inputRowType,
grouping,
auxGrouping,
aggCallToAggFunction,
isMerge = false,
isFinal = false) {
isFinal = false)
with LegacyBatchExecNode[RowData] {
override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new BatchExecLocalHashAggregate(
......@@ -121,6 +132,47 @@ class BatchExecLocalHashAggregate(
//~ ExecNode methods -----------------------------------------------------------
override protected def translateToPlanInternal(
planner: BatchPlanner): Transformation[RowData] = {
val config = planner.getTableConfig
val input = getInputNodes.get(0).translateToPlan(planner)
.asInstanceOf[Transformation[RowData]]
val ctx = CodeGeneratorContext(config)
val outputType = FlinkTypeFactory.toLogicalRowType(getRowType)
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputRowType), getAggCallList)
var managedMemory: Long = 0L
val generatedOperator = if (grouping.isEmpty) {
AggWithoutKeysCodeGenerator.genWithoutKeys(
ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping")
} else {
managedMemory = MemorySize.parse(config.getConfiguration.getString(
ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY)).getBytes
new HashAggCodeGenerator(
ctx,
planner.getRelBuilder,
aggInfos,
inputType,
outputType,
grouping,
auxGrouping,
isMerge,
isFinal
).genWithKeys()
}
val operator = new CodeGenOperatorFactory[RowData](generatedOperator)
ExecNodeUtil.createOneInputTransformation(
input,
getRelDetailedDescription,
operator,
InternalTypeInfo.of(outputType),
input.getParallelism,
managedMemory)
}
override def getInputEdges: util.List[ExecEdge] = {
if (grouping.length == 0) {
List(
......
......@@ -19,8 +19,10 @@
package org.apache.flink.table.planner.plan.nodes.physical.batch
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef}
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge
import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecHashAggregate
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
import org.apache.flink.table.planner.plan.rules.physical.batch.BatchExecJoinRuleBase
import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, RelExplainUtil}
......@@ -40,7 +42,7 @@ import scala.collection.JavaConversions._
*
* @see [[BatchPhysicalGroupAggregateBase]] for more info.
*/
class BatchExecHashAggregate(
class BatchPhysicalHashAggregate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
......@@ -51,13 +53,11 @@ class BatchExecHashAggregate(
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
isMerge: Boolean)
extends BatchExecHashAggregateBase(
extends BatchPhysicalHashAggregateBase(
cluster,
traitSet,
inputRel,
outputRowType,
inputRowType,
aggInputRowType,
grouping,
auxGrouping,
aggCallToAggFunction,
......@@ -65,7 +65,7 @@ class BatchExecHashAggregate(
isFinal = true) {
override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new BatchExecHashAggregate(
new BatchPhysicalHashAggregate(
cluster,
traitSet,
inputs.get(0),
......@@ -141,10 +141,17 @@ class BatchExecHashAggregate(
Some(copy(newProvidedTraitSet, Seq(newInput)))
}
//~ ExecNode methods -----------------------------------------------------------
override def getInputEdges: util.List[ExecEdge] = List(
ExecEdge.builder()
.damBehavior(ExecEdge.DamBehavior.END_INPUT)
.build())
override def translateToExecNode(): ExecNode[_] = {
new BatchExecHashAggregate(
grouping,
auxGrouping,
getAggCallList.toArray,
FlinkTypeFactory.toLogicalRowType(aggInputRowType),
isMerge,
true, // isFinal is always true
ExecEdge.builder().damBehavior(ExecEdge.DamBehavior.END_INPUT).build(),
FlinkTypeFactory.toLogicalRowType(getRowType),
getRelDetailedDescription
)
}
}
......@@ -17,23 +17,10 @@
*/
package org.apache.flink.table.planner.plan.nodes.physical.batch
import org.apache.flink.api.dag.Transformation
import org.apache.flink.configuration.MemorySize
import org.apache.flink.table.api.config.ExecutionConfigOptions
import org.apache.flink.table.data.RowData
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGeneratorContext
import org.apache.flink.table.planner.codegen.agg.batch.{AggWithoutKeysCodeGenerator, HashAggCodeGenerator}
import org.apache.flink.table.planner.delegation.BatchPlanner
import org.apache.flink.table.planner.plan.cost.FlinkCost._
import org.apache.flink.table.planner.plan.cost.FlinkCostFactory
import org.apache.flink.table.planner.plan.nodes.exec.LegacyBatchExecNode
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil
import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToBatchAggregateInfoList
import org.apache.flink.table.planner.plan.utils.FlinkRelMdUtil
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.rel.RelNode
......@@ -47,13 +34,11 @@ import org.apache.calcite.util.Util
*
* @see [[BatchPhysicalGroupAggregateBase]] for more info.
*/
abstract class BatchExecHashAggregateBase(
abstract class BatchPhysicalHashAggregateBase(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
outputRowType: RelDataType,
inputRowType: RelDataType,
aggInputRowType: RelDataType,
grouping: Array[Int],
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
......@@ -68,8 +53,7 @@ abstract class BatchExecHashAggregateBase(
auxGrouping,
aggCallToAggFunction,
isMerge,
isFinal)
with LegacyBatchExecNode[RowData] {
isFinal) {
override def computeSelfCost(planner: RelOptPlanner, mq: RelMetadataQuery): RelOptCost = {
val numOfGroupKey = grouping.length
......@@ -97,47 +81,4 @@ abstract class BatchExecHashAggregateBase(
val costFactory = planner.getCostFactory.asInstanceOf[FlinkCostFactory]
costFactory.makeCost(rowCount, cpuCost, 0, 0, memCost)
}
//~ ExecNode methods -----------------------------------------------------------
override protected def translateToPlanInternal(
planner: BatchPlanner): Transformation[RowData] = {
val config = planner.getTableConfig
val input = getInputNodes.get(0).translateToPlan(planner)
.asInstanceOf[Transformation[RowData]]
val ctx = CodeGeneratorContext(config)
val outputType = FlinkTypeFactory.toLogicalRowType(getRowType)
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(aggInputRowType), getAggCallList)
var managedMemory: Long = 0L
val generatedOperator = if (grouping.isEmpty) {
AggWithoutKeysCodeGenerator.genWithoutKeys(
ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping")
} else {
managedMemory = MemorySize.parse(config.getConfiguration.getString(
ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY)).getBytes
new HashAggCodeGenerator(
ctx,
planner.getRelBuilder,
aggInfos,
inputType,
outputType,
grouping,
auxGrouping,
isMerge,
isFinal
).genWithKeys()
}
val operator = new CodeGenOperatorFactory[RowData](generatedOperator)
ExecNodeUtil.createOneInputTransformation(
input,
getRelDetailedDescription,
operator,
InternalTypeInfo.of(outputType),
input.getParallelism,
managedMemory)
}
}
......@@ -414,7 +414,7 @@ object FlinkBatchRuleSets {
// expand
BatchPhysicalExpandRule.INSTANCE,
// group agg
BatchExecHashAggRule.INSTANCE,
BatchPhysicalHashAggRule.INSTANCE,
BatchExecSortAggRule.INSTANCE,
RemoveRedundantLocalSortAggRule.WITHOUT_SORT,
RemoveRedundantLocalSortAggRule.WITH_SORT,
......
......@@ -22,7 +22,7 @@ import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory}
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecHashAggregate
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalHashAggregate
import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, OperatorType}
import org.apache.flink.table.planner.utils.TableConfigUtils.isOperatorDisabled
......@@ -37,7 +37,7 @@ import scala.collection.JavaConversions._
* Rule that matches a [[FlinkLogicalAggregate]] whose aggregate function buffers are all
* fixed-length,
* and converts it to
* {{{
* BatchExecHashAggregate (global)
* BatchPhysicalHashAggregate (global)
* +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
* +- BatchExecLocalHashAggregate (local)
* +- input of agg
......@@ -45,7 +45,7 @@ import scala.collection.JavaConversions._
* when all aggregate functions are mergeable
* and [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is TWO_PHASE, or
* {{{
* BatchExecHashAggregate
* BatchPhysicalHashAggregate
* +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
* +- input of agg
* }}}
......@@ -55,11 +55,11 @@ import scala.collection.JavaConversions._
* Notes: if [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is NONE,
* this rule will try to create two possibilities above, and chooses the best one based on cost.
*/
class BatchExecHashAggRule
class BatchPhysicalHashAggRule
extends RelOptRule(
operand(classOf[FlinkLogicalAggregate],
operand(classOf[RelNode], any)),
"BatchExecHashAggRule")
"BatchPhysicalHashAggRule")
with BatchPhysicalAggRuleBase {
override def matches(call: RelOptRuleCall): Boolean = {
......@@ -109,7 +109,7 @@ class BatchExecHashAggRule
aggCallToAggFunction,
isLocalHashAgg = true)
// create global BatchExecHashAggregate
// create global BatchPhysicalHashAggregate
val (globalGroupSet, globalAuxGroupSet) = getGlobalAggGroupSetPair(groupSet, auxGroupSet)
val globalDistributions = if (agg.getGroupCount != 0) {
val distributionFields = globalGroupSet.map(Integer.valueOf).toList
......@@ -133,7 +133,7 @@ class BatchExecHashAggRule
globalDistributions.foreach { globalDistribution =>
val requiredTraitSet = localHashAgg.getTraitSet.replace(globalDistribution)
val newLocalHashAgg = RelOptRule.convert(localHashAgg, requiredTraitSet)
val globalHashAgg = new BatchExecHashAggregate(
val globalHashAgg = new BatchPhysicalHashAggregate(
agg.getCluster,
aggProvidedTraitSet,
newLocalHashAgg,
......@@ -163,7 +163,7 @@ class BatchExecHashAggRule
.replace(FlinkConventions.BATCH_PHYSICAL)
.replace(requiredDistribution)
val newInput = RelOptRule.convert(input, requiredTraitSet)
val hashAgg = new BatchExecHashAggregate(
val hashAgg = new BatchPhysicalHashAggregate(
agg.getCluster,
aggProvidedTraitSet,
newInput,
......@@ -180,6 +180,6 @@ class BatchExecHashAggRule
}
}
object BatchExecHashAggRule {
val INSTANCE = new BatchExecHashAggRule
object BatchPhysicalHashAggRule {
val INSTANCE = new BatchPhysicalHashAggRule
}
......@@ -22,7 +22,7 @@ import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand, BatchPhysicalGroupAggregateBase}
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand, BatchPhysicalGroupAggregateBase, BatchPhysicalHashAggregate}
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleOperand}
......@@ -81,7 +81,7 @@ abstract class EnforceLocalAggRuleBase(
.replace(FlinkConventions.BATCH_PHYSICAL)
val isLocalHashAgg = completeAgg match {
case _: BatchExecHashAggregate => true
case _: BatchPhysicalHashAggregate => true
case _: BatchExecSortAggregate => false
case _ =>
throw new TableException(s"Unsupported aggregate: ${completeAgg.getClass.getSimpleName}")
......@@ -131,8 +131,8 @@ abstract class EnforceLocalAggRuleBase(
val aggInputRowType = completeAgg.getInput.getRowType
completeAgg match {
case _: BatchExecHashAggregate =>
new BatchExecHashAggregate(
case _: BatchPhysicalHashAggregate =>
new BatchPhysicalHashAggregate(
completeAgg.getCluster,
completeAgg.getTraitSet,
input,
......
......@@ -18,13 +18,13 @@
package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecHashAggregate, BatchPhysicalExchange, BatchPhysicalExpand}
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalExchange, BatchPhysicalExpand, BatchPhysicalHashAggregate}
import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.RelOptRuleCall
/**
* An [[EnforceLocalAggRuleBase]] that matches [[BatchExecHashAggregate]]
* An [[EnforceLocalAggRuleBase]] that matches [[BatchPhysicalHashAggregate]]
*
* for example: select count(*) from t group by rollup (a, b)
* The physical plan
......@@ -49,13 +49,13 @@ import org.apache.calcite.plan.RelOptRuleCall
* }}}
*/
class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase(
operand(classOf[BatchExecHashAggregate],
operand(classOf[BatchPhysicalHashAggregate],
operand(classOf[BatchPhysicalExchange],
operand(classOf[BatchPhysicalExpand], any))),
"EnforceLocalHashAggRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val agg: BatchExecHashAggregate = call.rel(0)
val agg: BatchPhysicalHashAggregate = call.rel(0)
val expand: BatchPhysicalExpand = call.rel(2)
val enableTwoPhaseAgg = isTwoPhaseAggEnabled(agg)
......@@ -67,7 +67,7 @@ class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase(
}
override def onMatch(call: RelOptRuleCall): Unit = {
val agg: BatchExecHashAggregate = call.rel(0)
val agg: BatchPhysicalHashAggregate = call.rel(0)
val expand: BatchPhysicalExpand = call.rel(2)
val localAgg = createLocalAgg(agg, expand)
......
......@@ -19,7 +19,7 @@
package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecHashAggregate, BatchExecLocalHashAggregate}
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalHashAggregate, BatchExecLocalHashAggregate}
import org.apache.calcite.plan.RelOptRule._
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
......@@ -30,16 +30,16 @@ import org.apache.calcite.rel.RelNode
* shuffle is removed. The rule could remove redundant localHashAggregate node.
*/
class RemoveRedundantLocalHashAggRule extends RelOptRule(
operand(classOf[BatchExecHashAggregate],
operand(classOf[BatchPhysicalHashAggregate],
operand(classOf[BatchExecLocalHashAggregate],
operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any))),
"RemoveRedundantLocalHashAggRule") {
override def onMatch(call: RelOptRuleCall): Unit = {
val globalAgg: BatchExecHashAggregate = call.rel(0)
val globalAgg: BatchPhysicalHashAggregate = call.rel(0)
val localAgg: BatchExecLocalHashAggregate = call.rel(1)
val inputOfLocalAgg = localAgg.getInput
val newGlobalAgg = new BatchExecHashAggregate(
val newGlobalAgg = new BatchPhysicalHashAggregate(
globalAgg.getCluster,
globalAgg.getTraitSet,
inputOfLocalAgg,
......
......@@ -991,7 +991,7 @@ class FlinkRelMdHandlerTestBase {
val batchExchange1 = new BatchPhysicalExchange(
cluster, batchLocalAgg.getTraitSet.replace(hash0), batchLocalAgg, hash0)
val batchGlobalAgg = new BatchExecHashAggregate(
val batchGlobalAgg = new BatchPhysicalHashAggregate(
cluster,
batchPhysicalTraits,
batchExchange1,
......@@ -1005,7 +1005,7 @@ class FlinkRelMdHandlerTestBase {
val batchExchange2 = new BatchPhysicalExchange(cluster,
studentBatchScan.getTraitSet.replace(hash3), studentBatchScan, hash3)
val batchGlobalAggWithoutLocal = new BatchExecHashAggregate(
val batchGlobalAggWithoutLocal = new BatchPhysicalHashAggregate(
cluster,
batchPhysicalTraits,
batchExchange2,
......@@ -1127,7 +1127,7 @@ class FlinkRelMdHandlerTestBase {
.add("avg_score", doubleType)
.add("sum_score", doubleType)
.add("cnt", longType).build()
val batchGlobalAggWithAuxGroup = new BatchExecHashAggregate(
val batchGlobalAggWithAuxGroup = new BatchPhysicalHashAggregate(
cluster,
batchPhysicalTraits,
batchExchange,
......@@ -1141,7 +1141,7 @@ class FlinkRelMdHandlerTestBase {
val batchExchange2 = new BatchPhysicalExchange(cluster,
studentBatchScan.getTraitSet.replace(hash0), studentBatchScan, hash0)
val batchGlobalAggWithoutLocalWithAuxGroup = new BatchExecHashAggregate(
val batchGlobalAggWithoutLocalWithAuxGroup = new BatchPhysicalHashAggregate(
cluster,
batchPhysicalTraits,
batchExchange2,
......
......@@ -42,9 +42,9 @@ class EnforceLocalHashAggRuleTest extends EnforceLocalAggRuleTestBase {
// remove the original BatchPhysicalHashAggRule and add BatchPhysicalHashAggRuleForOnePhase
// to let the physical phase generate one phase aggregate
program.getFlinkRuleSetProgram(FlinkBatchProgram.PHYSICAL)
.get.remove(RuleSets.ofList(BatchExecHashAggRule.INSTANCE))
.get.remove(RuleSets.ofList(BatchPhysicalHashAggRule.INSTANCE))
program.getFlinkRuleSetProgram(FlinkBatchProgram.PHYSICAL)
.get.add(RuleSets.ofList(BatchExecHashAggRuleForOnePhase.INSTANCE))
.get.add(RuleSets.ofList(BatchPhysicalHashAggRuleForOnePhase.INSTANCE))
var calciteConfig = TableConfigUtils.getCalciteConfig(util.tableEnv.getConfig)
calciteConfig = CalciteConfig.createBuilder(calciteConfig)
......@@ -64,7 +64,7 @@ class EnforceLocalHashAggRuleTest extends EnforceLocalAggRuleTestBase {
* value, and only enable one phase aggregate.
* This rule only used for test.
*/
class BatchExecHashAggRuleForOnePhase extends BatchExecHashAggRule {
class BatchPhysicalHashAggRuleForOnePhase extends BatchPhysicalHashAggRule {
override protected def isTwoPhaseAggWorkable(
aggFunctions: Array[UserDefinedFunction], tableConfig: TableConfig): Boolean = false
......@@ -72,6 +72,6 @@ class BatchExecHashAggRuleForOnePhase extends BatchExecHashAggRule {
aggFunctions: Array[UserDefinedFunction], tableConfig: TableConfig): Boolean = true
}
object BatchExecHashAggRuleForOnePhase {
val INSTANCE = new BatchExecHashAggRuleForOnePhase
object BatchPhysicalHashAggRuleForOnePhase {
val INSTANCE = new BatchPhysicalHashAggRuleForOnePhase
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册