提交 53a888c5 编写于 作者: G godfreyhe

[FLINK-20737][table-planner-blink] Use RowType instead of RelDataType when building aggregate info

This closes #14478
上级 5351203f
...@@ -21,6 +21,7 @@ package org.apache.flink.table.planner.plan.rules.physical.batch; ...@@ -21,6 +21,7 @@ package org.apache.flink.table.planner.plan.rules.physical.batch;
import org.apache.flink.table.api.TableException; import org.apache.flink.table.api.TableException;
import org.apache.flink.table.functions.UserDefinedFunction; import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.functions.python.PythonFunctionKind; import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.plan.nodes.FlinkConventions; import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupAggregate; import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupAggregate;
...@@ -106,7 +107,9 @@ public class BatchExecPythonAggregateRule extends ConverterRule { ...@@ -106,7 +107,9 @@ public class BatchExecPythonAggregateRule extends ConverterRule {
Tuple3<int[][], DataType[][], UserDefinedFunction[]> aggBufferTypesAndFunctions = Tuple3<int[][], DataType[][], UserDefinedFunction[]> aggBufferTypesAndFunctions =
AggregateUtil.transformToBatchAggregateFunctions( AggregateUtil.transformToBatchAggregateFunctions(
aggCallsWithoutAuxGroupCalls, input.getRowType(), null); FlinkTypeFactory.toLogicalRowType(input.getRowType()),
aggCallsWithoutAuxGroupCalls,
null);
UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3(); UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3();
RelTraitSet requiredTraitSet = RelTraitSet requiredTraitSet =
......
...@@ -22,6 +22,7 @@ import org.apache.flink.table.api.TableException; ...@@ -22,6 +22,7 @@ import org.apache.flink.table.api.TableException;
import org.apache.flink.table.functions.UserDefinedFunction; import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.functions.python.PythonFunctionKind; import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.calcite.FlinkRelFactories; import org.apache.flink.table.planner.calcite.FlinkRelFactories;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.plan.logical.LogicalWindow; import org.apache.flink.table.planner.plan.logical.LogicalWindow;
import org.apache.flink.table.planner.plan.logical.SessionGroupWindow; import org.apache.flink.table.planner.plan.logical.SessionGroupWindow;
import org.apache.flink.table.planner.plan.logical.SlidingGroupWindow; import org.apache.flink.table.planner.plan.logical.SlidingGroupWindow;
...@@ -121,7 +122,9 @@ public class BatchExecPythonWindowAggregateRule extends RelOptRule { ...@@ -121,7 +122,9 @@ public class BatchExecPythonWindowAggregateRule extends RelOptRule {
Tuple3<int[][], DataType[][], UserDefinedFunction[]> aggBufferTypesAndFunctions = Tuple3<int[][], DataType[][], UserDefinedFunction[]> aggBufferTypesAndFunctions =
AggregateUtil.transformToBatchAggregateFunctions( AggregateUtil.transformToBatchAggregateFunctions(
aggCallsWithoutAuxGroupCalls, input.getRowType(), null); FlinkTypeFactory.toLogicalRowType(input.getRowType()),
aggCallsWithoutAuxGroupCalls,
null);
UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3(); UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3();
int inputTimeFieldIndex = int inputTimeFieldIndex =
......
...@@ -682,8 +682,8 @@ class MatchCodeGenerator( ...@@ -682,8 +682,8 @@ class MatchCodeGenerator(
matchAgg.inputExprs.indices.map(i => s"TMP$i")) matchAgg.inputExprs.indices.map(i => s"TMP$i"))
val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList( val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputRelType),
aggCalls, aggCalls,
inputRelType,
needRetraction, needRetraction,
needInputCount = false, needInputCount = false,
isStateBackendDataViews = false, isStateBackendDataViews = false,
......
...@@ -114,7 +114,7 @@ abstract class BatchExecHashAggregateBase( ...@@ -114,7 +114,7 @@ abstract class BatchExecHashAggregateBase(
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList( val aggInfos = transformToBatchAggregateInfoList(
aggCallToAggFunction.map(_._1), aggInputRowType) FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1))
var managedMemory: Long = 0L var managedMemory: Long = 0L
val generatedOperator = if (grouping.isEmpty) { val generatedOperator = if (grouping.isEmpty) {
......
...@@ -113,7 +113,7 @@ abstract class BatchExecHashWindowAggregateBase( ...@@ -113,7 +113,7 @@ abstract class BatchExecHashWindowAggregateBase(
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList( val aggInfos = transformToBatchAggregateInfoList(
aggCallToAggFunction.map(_._1), aggInputRowType) FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1))
val groupBufferLimitSize = config.getConfiguration.getInteger( val groupBufferLimitSize = config.getConfiguration.getInteger(
ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT) ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT)
......
...@@ -140,10 +140,10 @@ class BatchExecOverAggregate( ...@@ -140,10 +140,10 @@ class BatchExecOverAggregate(
//operator needn't cache data //operator needn't cache data
val aggHandlers = modeToGroupToAggCallToAggFunction.map { case (_, _, aggCallToAggFunction) => val aggHandlers = modeToGroupToAggCallToAggFunction.map { case (_, _, aggCallToAggFunction) =>
val aggInfoList = transformToBatchAggregateInfoList( val aggInfoList = transformToBatchAggregateInfoList(
aggCallToAggFunction.map(_._1),
// use aggInputType which considers constants as input instead of inputType // use aggInputType which considers constants as input instead of inputType
inputTypeWithConstants, FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants),
orderKeyIndices) aggCallToAggFunction.map(_._1),
orderKeyIndexes = orderKeyIndices)
val codeGenCtx = CodeGeneratorContext(config) val codeGenCtx = CodeGeneratorContext(config)
val generator = new AggsHandlerCodeGenerator( val generator = new AggsHandlerCodeGenerator(
codeGenCtx, codeGenCtx,
...@@ -191,10 +191,10 @@ class BatchExecOverAggregate( ...@@ -191,10 +191,10 @@ class BatchExecOverAggregate(
//lies on the offset of the window frame. //lies on the offset of the window frame.
aggCallToAggFunction.map { case (aggCall, _) => aggCallToAggFunction.map { case (aggCall, _) =>
val aggInfoList = transformToBatchAggregateInfoList( val aggInfoList = transformToBatchAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants),
Seq(aggCall), Seq(aggCall),
inputTypeWithConstants, Array[Boolean](true), /* needRetraction = true, See LeadLagAggFunction */
orderKeyIndices, orderKeyIndexes = orderKeyIndices)
Array[Boolean](true) /* needRetraction = true, See LeadLagAggFunction */)
val generator = new AggsHandlerCodeGenerator( val generator = new AggsHandlerCodeGenerator(
CodeGeneratorContext(config), CodeGeneratorContext(config),
...@@ -263,10 +263,10 @@ class BatchExecOverAggregate( ...@@ -263,10 +263,10 @@ class BatchExecOverAggregate(
case _ => case _ =>
val aggInfoList = transformToBatchAggregateInfoList( val aggInfoList = transformToBatchAggregateInfoList(
aggCallToAggFunction.map(_._1),
//use aggInputType which considers constants as input instead of inputSchema.relDataType //use aggInputType which considers constants as input instead of inputSchema.relDataType
inputTypeWithConstants, FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants),
orderKeyIndices) aggCallToAggFunction.map(_._1),
orderKeyIndexes = orderKeyIndices)
val codeGenCtx = CodeGeneratorContext(config) val codeGenCtx = CodeGeneratorContext(config)
val generator = new AggsHandlerCodeGenerator( val generator = new AggsHandlerCodeGenerator(
codeGenCtx, codeGenCtx,
......
...@@ -95,7 +95,7 @@ abstract class BatchExecSortAggregateBase( ...@@ -95,7 +95,7 @@ abstract class BatchExecSortAggregateBase(
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList( val aggInfos = transformToBatchAggregateInfoList(
aggCallToAggFunction.map(_._1), aggInputRowType) FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1))
val generatedOperator = if (grouping.isEmpty) { val generatedOperator = if (grouping.isEmpty) {
AggWithoutKeysCodeGenerator.genWithoutKeys( AggWithoutKeysCodeGenerator.genWithoutKeys(
......
...@@ -101,7 +101,7 @@ abstract class BatchExecSortWindowAggregateBase( ...@@ -101,7 +101,7 @@ abstract class BatchExecSortWindowAggregateBase(
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList( val aggInfos = transformToBatchAggregateInfoList(
aggCallToAggFunction.map(_._1), aggInputRowType) FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1))
val groupBufferLimitSize = planner.getTableConfig.getConfiguration.getInteger( val groupBufferLimitSize = planner.getTableConfig.getConfiguration.getInteger(
ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT) ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT)
......
...@@ -63,8 +63,8 @@ class StreamExecGroupAggregate( ...@@ -63,8 +63,8 @@ class StreamExecGroupAggregate(
val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList( val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList(
this, this,
aggCalls, grouping.length,
grouping) aggCalls)
override def requireWatermark: Boolean = false override def requireWatermark: Boolean = false
......
...@@ -42,8 +42,8 @@ abstract class StreamExecGroupTableAggregateBase( ...@@ -42,8 +42,8 @@ abstract class StreamExecGroupTableAggregateBase(
val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList( val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList(
this, this,
aggCalls, grouping.length,
grouping) aggCalls)
override def requireWatermark: Boolean = false override def requireWatermark: Boolean = false
......
...@@ -137,8 +137,8 @@ abstract class StreamExecGroupWindowAggregateBase( ...@@ -137,8 +137,8 @@ abstract class StreamExecGroupWindowAggregateBase(
val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this) val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this)
val aggInfoList = transformToStreamAggregateInfoList( val aggInfoList = transformToStreamAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputRowType),
aggCalls, aggCalls,
inputRowType,
Array.fill(aggCalls.size)(needRetraction), Array.fill(aggCalls.size)(needRetraction),
needInputCount = needRetraction, needInputCount = needRetraction,
isStateBackendDataViews = true) isStateBackendDataViews = true)
......
...@@ -243,9 +243,9 @@ class StreamExecOverAggregate( ...@@ -243,9 +243,9 @@ class StreamExecOverAggregate(
val needRetraction = false val needRetraction = false
val aggInfoList = transformToStreamAggregateInfoList( val aggInfoList = transformToStreamAggregateInfoList(
aggregateCalls,
// use aggInputType which considers constants as input instead of inputSchema.relDataType // use aggInputType which considers constants as input instead of inputSchema.relDataType
aggInputType, FlinkTypeFactory.toLogicalRowType(aggInputType),
aggregateCalls,
Array.fill(aggregateCalls.size)(needRetraction), Array.fill(aggregateCalls.size)(needRetraction),
needInputCount = needRetraction, needInputCount = needRetraction,
isStateBackendDataViews = true) isStateBackendDataViews = true)
...@@ -322,9 +322,9 @@ class StreamExecOverAggregate( ...@@ -322,9 +322,9 @@ class StreamExecOverAggregate(
val needRetraction = true val needRetraction = true
val aggInfoList = transformToStreamAggregateInfoList( val aggInfoList = transformToStreamAggregateInfoList(
aggregateCalls,
// use aggInputType which considers constants as input instead of inputSchema.relDataType // use aggInputType which considers constants as input instead of inputSchema.relDataType
aggInputType, FlinkTypeFactory.toLogicalRowType(aggInputType),
aggregateCalls,
Array.fill(aggregateCalls.size)(needRetraction), Array.fill(aggregateCalls.size)(needRetraction),
needInputCount = needRetraction, needInputCount = needRetraction,
isStateBackendDataViews = true) isStateBackendDataViews = true)
......
...@@ -59,8 +59,8 @@ class StreamExecPythonGroupAggregate( ...@@ -59,8 +59,8 @@ class StreamExecPythonGroupAggregate(
val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList( val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList(
this, this,
aggCalls, grouping.length,
grouping) aggCalls)
override def requireWatermark: Boolean = false override def requireWatermark: Boolean = false
......
...@@ -154,7 +154,7 @@ trait BatchExecAggRuleBase { ...@@ -154,7 +154,7 @@ trait BatchExecAggRuleBase {
protected def isAggBufferFixedLength(agg: Aggregate): Boolean = { protected def isAggBufferFixedLength(agg: Aggregate): Boolean = {
val (_, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg) val (_, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
val (_, aggBufferTypes, _) = AggregateUtil.transformToBatchAggregateFunctions( val (_, aggBufferTypes, _) = AggregateUtil.transformToBatchAggregateFunctions(
aggCallsWithoutAuxGroupCalls, agg.getInput.getRowType) FlinkTypeFactory.toLogicalRowType(agg.getInput.getRowType), aggCallsWithoutAuxGroupCalls)
isAggBufferFixedLength(aggBufferTypes.map(_.map(fromDataTypeToLogicalType))) isAggBufferFixedLength(aggBufferTypes.map(_.map(fromDataTypeToLogicalType)))
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
package org.apache.flink.table.planner.plan.rules.physical.batch package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.api.config.OptimizerConfigOptions import org.apache.flink.table.api.config.OptimizerConfigOptions
import org.apache.flink.table.planner.calcite.FlinkContext import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory}
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate
...@@ -87,7 +87,7 @@ class BatchExecHashAggRule ...@@ -87,7 +87,7 @@ class BatchExecHashAggRule
val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg) val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
val (_, aggBufferTypes, aggFunctions) = AggregateUtil.transformToBatchAggregateFunctions( val (_, aggBufferTypes, aggFunctions) = AggregateUtil.transformToBatchAggregateFunctions(
aggCallsWithoutAuxGroupCalls, inputRowType) FlinkTypeFactory.toLogicalRowType(inputRowType), aggCallsWithoutAuxGroupCalls)
val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggFunctions) val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggFunctions)
val aggProvidedTraitSet = agg.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL) val aggProvidedTraitSet = agg.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)
......
...@@ -94,7 +94,9 @@ class BatchExecOverAggregateRule ...@@ -94,7 +94,9 @@ class BatchExecOverAggregateRule
val groupToAggCallToAggFunction = groupBuffer.map { group => val groupToAggCallToAggFunction = groupBuffer.map { group =>
val aggregateCalls = group.getAggregateCalls(logicWindow) val aggregateCalls = group.getAggregateCalls(logicWindow)
val (_, _, aggregates) = AggregateUtil.transformToBatchAggregateFunctions( val (_, _, aggregates) = AggregateUtil.transformToBatchAggregateFunctions(
aggregateCalls, inputTypeWithConstants, orderKeyIndexes) FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants),
aggregateCalls,
orderKeyIndexes)
val aggCallToAggFunction = aggregateCalls.zip(aggregates) val aggCallToAggFunction = aggregateCalls.zip(aggregates)
(group, aggCallToAggFunction) (group, aggCallToAggFunction)
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
package org.apache.flink.table.planner.plan.rules.physical.batch package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.api.config.OptimizerConfigOptions import org.apache.flink.table.api.config.OptimizerConfigOptions
import org.apache.flink.table.planner.calcite.FlinkContext import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory}
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate
...@@ -80,7 +80,7 @@ class BatchExecSortAggRule ...@@ -80,7 +80,7 @@ class BatchExecSortAggRule
val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg) val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
val (_, aggBufferTypes, aggFunctions) = AggregateUtil.transformToBatchAggregateFunctions( val (_, aggBufferTypes, aggFunctions) = AggregateUtil.transformToBatchAggregateFunctions(
aggCallsWithoutAuxGroupCalls, inputRowType) FlinkTypeFactory.toLogicalRowType(inputRowType), aggCallsWithoutAuxGroupCalls)
val groupSet = agg.getGroupSet.toArray val groupSet = agg.getGroupSet.toArray
val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggFunctions) val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggFunctions)
// TODO aggregate include projection now, so do not provide new trait will be safe // TODO aggregate include projection now, so do not provide new trait will be safe
......
...@@ -101,7 +101,7 @@ class BatchExecWindowAggregateRule ...@@ -101,7 +101,7 @@ class BatchExecWindowAggregateRule
val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg) val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
val (_, aggBufferTypes, aggregates) = AggregateUtil.transformToBatchAggregateFunctions( val (_, aggBufferTypes, aggregates) = AggregateUtil.transformToBatchAggregateFunctions(
aggCallsWithoutAuxGroupCalls, input.getRowType) FlinkTypeFactory.toLogicalRowType(input.getRowType), aggCallsWithoutAuxGroupCalls)
val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggregates) val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggregates)
val internalAggBufferTypes = aggBufferTypes.map(_.map(fromDataTypeToLogicalType)) val internalAggBufferTypes = aggBufferTypes.map(_.map(fromDataTypeToLogicalType))
val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig
......
...@@ -19,9 +19,10 @@ ...@@ -19,9 +19,10 @@
package org.apache.flink.table.planner.plan.rules.physical.batch package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.api.TableException import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalExpand, BatchExecGroupAggregateBase, BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange} import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand}
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil} import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleOperand} import org.apache.calcite.plan.{RelOptRule, RelOptRuleOperand}
...@@ -75,7 +76,7 @@ abstract class EnforceLocalAggRuleBase( ...@@ -75,7 +76,7 @@ abstract class EnforceLocalAggRuleBase(
val aggCallToAggFunction = completeAgg.getAggCallToAggFunction val aggCallToAggFunction = completeAgg.getAggCallToAggFunction
val (_, aggBufferTypes, _) = AggregateUtil.transformToBatchAggregateFunctions( val (_, aggBufferTypes, _) = AggregateUtil.transformToBatchAggregateFunctions(
aggCalls, inputRowType) FlinkTypeFactory.toLogicalRowType(inputRowType), aggCalls)
val traitSet = cluster.getPlanner val traitSet = cluster.getPlanner
.emptyTraitSet .emptyTraitSet
......
...@@ -132,9 +132,9 @@ class IncrementalAggregateRule ...@@ -132,9 +132,9 @@ class IncrementalAggregateRule
} else { } else {
// an additional count1 is inserted, need to adapt the global agg // an additional count1 is inserted, need to adapt the global agg
val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
aggCalls,
// the final agg input is partial agg // the final agg input is partial agg
partialGlobalAgg.getRowType, FlinkTypeFactory.toLogicalRowType(partialGlobalAgg.getRowType),
aggCalls,
// all the aggs do not need retraction // all the aggs do not need retraction
Array.fill(aggCalls.length)(false), Array.fill(aggCalls.length)(false),
// also do not need count* // also do not need count*
...@@ -142,9 +142,9 @@ class IncrementalAggregateRule ...@@ -142,9 +142,9 @@ class IncrementalAggregateRule
// the local agg is not works on state // the local agg is not works on state
isStateBackendDataViews = false) isStateBackendDataViews = false)
val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
aggCalls,
// the final agg input is partial agg // the final agg input is partial agg
partialGlobalAgg.getRowType, FlinkTypeFactory.toLogicalRowType(partialGlobalAgg.getRowType),
aggCalls,
// all the aggs do not need retraction // all the aggs do not need retraction
Array.fill(aggCalls.length)(false), Array.fill(aggCalls.length)(false),
// also do not need count* // also do not need count*
......
...@@ -66,11 +66,11 @@ class TwoStageOptimizedAggregateRule extends RelOptRule( ...@@ -66,11 +66,11 @@ class TwoStageOptimizedAggregateRule extends RelOptRule(
val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery) val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery)
val monotonicity = fmq.getRelModifiedMonotonicity(agg) val monotonicity = fmq.getRelModifiedMonotonicity(agg)
val needRetractionArray = AggregateUtil.getNeedRetractions( val needRetractionArray = AggregateUtil.getNeedRetractions(
agg.grouping.length, needRetraction, monotonicity, agg.aggCalls) agg.grouping.length, agg.aggCalls, needRetraction, monotonicity)
val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList( val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
FlinkTypeFactory.toLogicalRowType( agg.getInput.getRowType),
agg.aggCalls, agg.aggCalls,
agg.getInput.getRowType,
needRetractionArray, needRetractionArray,
needRetraction, needRetraction,
isStateBackendDataViews = true) isStateBackendDataViews = true)
...@@ -98,18 +98,18 @@ class TwoStageOptimizedAggregateRule extends RelOptRule( ...@@ -98,18 +98,18 @@ class TwoStageOptimizedAggregateRule extends RelOptRule(
val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery) val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery)
val monotonicity = fmq.getRelModifiedMonotonicity(agg) val monotonicity = fmq.getRelModifiedMonotonicity(agg)
val needRetractionArray = AggregateUtil.getNeedRetractions( val needRetractionArray = AggregateUtil.getNeedRetractions(
agg.grouping.length, needRetraction, monotonicity, agg.aggCalls) agg.grouping.length, agg.aggCalls, needRetraction, monotonicity)
val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(realInput.getRowType),
agg.aggCalls, agg.aggCalls,
realInput.getRowType,
needRetractionArray, needRetractionArray,
needRetraction, needRetraction,
isStateBackendDataViews = false) isStateBackendDataViews = false)
val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(realInput.getRowType),
agg.aggCalls, agg.aggCalls,
realInput.getRowType,
needRetractionArray, needRetractionArray,
needRetraction, needRetraction,
isStateBackendDataViews = true) isStateBackendDataViews = true)
......
...@@ -19,24 +19,17 @@ package org.apache.flink.table.planner.plan.utils ...@@ -19,24 +19,17 @@ package org.apache.flink.table.planner.plan.utils
import org.apache.flink.table.api.TableException import org.apache.flink.table.api.TableException
import org.apache.flink.table.functions.UserDefinedFunction import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.functions.aggfunctions.FirstValueAggFunction._
import org.apache.flink.table.planner.functions.aggfunctions.FirstValueWithRetractAggFunction._
import org.apache.flink.table.planner.functions.aggfunctions.IncrSumAggFunction._ import org.apache.flink.table.planner.functions.aggfunctions.IncrSumAggFunction._
import org.apache.flink.table.planner.functions.aggfunctions.IncrSumWithRetractAggFunction._ import org.apache.flink.table.planner.functions.aggfunctions.IncrSumWithRetractAggFunction._
import org.apache.flink.table.planner.functions.aggfunctions.LastValueAggFunction._
import org.apache.flink.table.planner.functions.aggfunctions.LastValueWithRetractAggFunction._
import org.apache.flink.table.planner.functions.aggfunctions.SingleValueAggFunction._ import org.apache.flink.table.planner.functions.aggfunctions.SingleValueAggFunction._
import org.apache.flink.table.planner.functions.aggfunctions.SumWithRetractAggFunction._ import org.apache.flink.table.planner.functions.aggfunctions.SumWithRetractAggFunction._
import org.apache.flink.table.planner.functions.aggfunctions._ import org.apache.flink.table.planner.functions.aggfunctions._
import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction
import org.apache.flink.table.planner.functions.sql.{SqlFirstLastValueAggFunction, SqlListAggFunction} import org.apache.flink.table.planner.functions.sql.{SqlFirstLastValueAggFunction, SqlListAggFunction}
import org.apache.flink.table.planner.functions.utils.AggSqlFunction import org.apache.flink.table.planner.functions.utils.AggSqlFunction
import org.apache.flink.table.runtime.typeutils.DecimalDataTypeInfo
import org.apache.flink.table.types.logical.LogicalTypeRoot._ import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.table.types.logical._ import org.apache.flink.table.types.logical._
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.sql.fun._ import org.apache.calcite.sql.fun._
import org.apache.calcite.sql.{SqlAggFunction, SqlKind, SqlRankFunction} import org.apache.calcite.sql.{SqlAggFunction, SqlKind, SqlRankFunction}
...@@ -49,14 +42,14 @@ import scala.collection.JavaConversions._ ...@@ -49,14 +42,14 @@ import scala.collection.JavaConversions._
* The class of agg function factory which is used to create AggregateFunction or * The class of agg function factory which is used to create AggregateFunction or
* DeclarativeAggregateFunction from Calcite AggregateCall * DeclarativeAggregateFunction from Calcite AggregateCall
* *
* @param inputType the input rel data type * @param inputRowType the input's output RowType
* @param orderKeyIdx the indexes of order key (null when is not over agg) * @param orderKeyIndexes the indexes of order key (null when is not over agg)
* @param needRetraction true if need retraction * @param aggCallNeedRetractions true if need retraction
*/ */
class AggFunctionFactory( class AggFunctionFactory(
inputType: RelDataType, inputRowType: RowType,
orderKeyIdx: Array[Int], orderKeyIndexes: Array[Int],
needRetraction: Array[Boolean]) { aggCallNeedRetractions: Array[Boolean]) {
/** /**
* The entry point to create an aggregate function from the given AggregateCall * The entry point to create an aggregate function from the given AggregateCall
...@@ -64,8 +57,7 @@ class AggFunctionFactory( ...@@ -64,8 +57,7 @@ class AggFunctionFactory(
def createAggFunction(call: AggregateCall, index: Int): UserDefinedFunction = { def createAggFunction(call: AggregateCall, index: Int): UserDefinedFunction = {
val argTypes: Array[LogicalType] = call.getArgList val argTypes: Array[LogicalType] = call.getArgList
.map(inputType.getFieldList.get(_).getType) .map(inputRowType.getChildren.get(_))
.map(FlinkTypeFactory.toLogicalType)
.toArray .toArray
call.getAggregation match { call.getAggregation match {
...@@ -165,7 +157,7 @@ class AggFunctionFactory( ...@@ -165,7 +157,7 @@ class AggFunctionFactory(
private def createSumAggFunction( private def createSumAggFunction(
argTypes: Array[LogicalType], argTypes: Array[LogicalType],
index: Int): UserDefinedFunction = { index: Int): UserDefinedFunction = {
if (needRetraction(index)) { if (aggCallNeedRetractions(index)) {
argTypes(0).getTypeRoot match { argTypes(0).getTypeRoot match {
case TINYINT => case TINYINT =>
new ByteSumWithRetractAggFunction new ByteSumWithRetractAggFunction
...@@ -236,7 +228,7 @@ class AggFunctionFactory( ...@@ -236,7 +228,7 @@ class AggFunctionFactory(
private def createIncrSumAggFunction( private def createIncrSumAggFunction(
argTypes: Array[LogicalType], argTypes: Array[LogicalType],
index: Int): UserDefinedFunction = { index: Int): UserDefinedFunction = {
if (needRetraction(index)) { if (aggCallNeedRetractions(index)) {
argTypes(0).getTypeRoot match { argTypes(0).getTypeRoot match {
case TINYINT => case TINYINT =>
new ByteIncrSumWithRetractAggFunction new ByteIncrSumWithRetractAggFunction
...@@ -286,7 +278,7 @@ class AggFunctionFactory( ...@@ -286,7 +278,7 @@ class AggFunctionFactory(
index: Int) index: Int)
: UserDefinedFunction = { : UserDefinedFunction = {
val valueType = argTypes(0) val valueType = argTypes(0)
if (needRetraction(index)) { if (aggCallNeedRetractions(index)) {
valueType.getTypeRoot match { valueType.getTypeRoot match {
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL | case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL |
TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE => TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE =>
...@@ -370,7 +362,7 @@ class AggFunctionFactory( ...@@ -370,7 +362,7 @@ class AggFunctionFactory(
index: Int) index: Int)
: UserDefinedFunction = { : UserDefinedFunction = {
val valueType = argTypes(0) val valueType = argTypes(0)
if (needRetraction(index)) { if (aggCallNeedRetractions(index)) {
valueType.getTypeRoot match { valueType.getTypeRoot match {
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL | case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL |
TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE => TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE =>
...@@ -460,16 +452,12 @@ class AggFunctionFactory( ...@@ -460,16 +452,12 @@ class AggFunctionFactory(
} }
private def createRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { private def createRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
val argTypes = orderKeyIdx val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
.map(inputType.getFieldList.get(_).getType)
.map(FlinkTypeFactory.toLogicalType)
new RankAggFunction(argTypes) new RankAggFunction(argTypes)
} }
private def createDenseRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { private def createDenseRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
val argTypes = orderKeyIdx val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
.map(inputType.getFieldList.get(_).getType)
.map(FlinkTypeFactory.toLogicalType)
new DenseRankAggFunction(argTypes) new DenseRankAggFunction(argTypes)
} }
...@@ -478,7 +466,7 @@ class AggFunctionFactory( ...@@ -478,7 +466,7 @@ class AggFunctionFactory(
index: Int) index: Int)
: UserDefinedFunction = { : UserDefinedFunction = {
val valueType = argTypes(0) val valueType = argTypes(0)
if (needRetraction(index)) { if (aggCallNeedRetractions(index)) {
valueType.getTypeRoot match { valueType.getTypeRoot match {
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL => case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL =>
new FirstValueWithRetractAggFunction(valueType) new FirstValueWithRetractAggFunction(valueType)
...@@ -502,7 +490,7 @@ class AggFunctionFactory( ...@@ -502,7 +490,7 @@ class AggFunctionFactory(
index: Int) index: Int)
: UserDefinedFunction = { : UserDefinedFunction = {
val valueType = argTypes(0) val valueType = argTypes(0)
if (needRetraction(index)) { if (aggCallNeedRetractions(index)) {
valueType.getTypeRoot match { valueType.getTypeRoot match {
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL => case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL =>
new LastValueWithRetractAggFunction(valueType) new LastValueWithRetractAggFunction(valueType)
...@@ -524,7 +512,7 @@ class AggFunctionFactory( ...@@ -524,7 +512,7 @@ class AggFunctionFactory(
private def createListAggFunction( private def createListAggFunction(
argTypes: Array[LogicalType], argTypes: Array[LogicalType],
index: Int): UserDefinedFunction = { index: Int): UserDefinedFunction = {
if (needRetraction(index)) { if (aggCallNeedRetractions(index)) {
new ListAggWithRetractAggFunction new ListAggWithRetractAggFunction
} else { } else {
new ListAggFunction(1) new ListAggFunction(1)
...@@ -534,7 +522,7 @@ class AggFunctionFactory( ...@@ -534,7 +522,7 @@ class AggFunctionFactory(
private def createListAggWsFunction( private def createListAggWsFunction(
argTypes: Array[LogicalType], argTypes: Array[LogicalType],
index: Int): UserDefinedFunction = { index: Int): UserDefinedFunction = {
if (needRetraction(index)) { if (aggCallNeedRetractions(index)) {
new ListAggWsWithRetractAggFunction new ListAggWsWithRetractAggFunction
} else { } else {
new ListAggFunction(2) new ListAggFunction(2)
......
...@@ -151,12 +151,12 @@ object AggregateUtil extends Enumeration { ...@@ -151,12 +151,12 @@ object AggregateUtil extends Enumeration {
def getOutputIndexToAggCallIndexMap( def getOutputIndexToAggCallIndexMap(
aggregateCalls: Seq[AggregateCall], aggregateCalls: Seq[AggregateCall],
inputType: RelDataType, inputType: RelDataType,
orderKeyIdx: Array[Int] = null): util.Map[Integer, Integer] = { orderKeyIndexes: Array[Int] = null): util.Map[Integer, Integer] = {
val aggInfos = transformToAggregateInfoList( val aggInfos = transformToAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputType),
aggregateCalls, aggregateCalls,
inputType,
orderKeyIdx,
Array.fill(aggregateCalls.size)(false), Array.fill(aggregateCalls.size)(false),
orderKeyIndexes,
needInputCount = false, needInputCount = false,
isStateBackedDataViews = false, isStateBackedDataViews = false,
needDistinctInfo = false).aggInfos needDistinctInfo = false).aggInfos
...@@ -176,10 +176,10 @@ object AggregateUtil extends Enumeration { ...@@ -176,10 +176,10 @@ object AggregateUtil extends Enumeration {
} }
def deriveAggregateInfoList( def deriveAggregateInfoList(
aggNode: StreamPhysicalRel, agg: StreamPhysicalRel,
aggCalls: Seq[AggregateCall], groupCount: Int,
grouping: Array[Int]): AggregateInfoList = { aggCalls: Seq[AggregateCall]): AggregateInfoList = {
val input = aggNode.getInput(0) val input = agg.getInput(0)
// need to call `retract()` if input contains update or delete // need to call `retract()` if input contains update or delete
val modifyKindSetTrait = input.getTraitSet.getTrait(ModifyKindSetTraitDef.INSTANCE) val modifyKindSetTrait = input.getTraitSet.getTrait(ModifyKindSetTraitDef.INSTANCE)
val needRetraction = if (modifyKindSetTrait == null) { val needRetraction = if (modifyKindSetTrait == null) {
...@@ -188,29 +188,28 @@ object AggregateUtil extends Enumeration { ...@@ -188,29 +188,28 @@ object AggregateUtil extends Enumeration {
} else { } else {
!modifyKindSetTrait.modifyKindSet.isInsertOnly !modifyKindSetTrait.modifyKindSet.isInsertOnly
} }
val fmq = FlinkRelMetadataQuery.reuseOrCreate(aggNode.getCluster.getMetadataQuery) val fmq = FlinkRelMetadataQuery.reuseOrCreate(agg.getCluster.getMetadataQuery)
val monotonicity = fmq.getRelModifiedMonotonicity(aggNode) val monotonicity = fmq.getRelModifiedMonotonicity(agg)
val needRetractionArray = AggregateUtil.getNeedRetractions( val needRetractionArray = getNeedRetractions(groupCount, aggCalls, needRetraction, monotonicity)
grouping.length, needRetraction, monotonicity, aggCalls) transformToStreamAggregateInfoList(
AggregateUtil.transformToStreamAggregateInfoList( FlinkTypeFactory.toLogicalRowType(input.getRowType),
aggCalls, aggCalls,
input.getRowType,
needRetractionArray, needRetractionArray,
needInputCount = needRetraction, needInputCount = needRetraction,
isStateBackendDataViews = true) isStateBackendDataViews = true)
} }
def transformToBatchAggregateFunctions( def transformToBatchAggregateFunctions(
inputRowType: RowType,
aggregateCalls: Seq[AggregateCall], aggregateCalls: Seq[AggregateCall],
inputRowType: RelDataType, orderKeyIndexes: Array[Int] = null)
orderKeyIdx: Array[Int] = null)
: (Array[Array[Int]], Array[Array[DataType]], Array[UserDefinedFunction]) = { : (Array[Array[Int]], Array[Array[DataType]], Array[UserDefinedFunction]) = {
val aggInfos = transformToAggregateInfoList( val aggInfos = transformToAggregateInfoList(
aggregateCalls,
inputRowType, inputRowType,
orderKeyIdx, aggregateCalls,
Array.fill(aggregateCalls.size)(false), Array.fill(aggregateCalls.size)(false),
orderKeyIndexes,
needInputCount = false, needInputCount = false,
isStateBackedDataViews = false, isStateBackedDataViews = false,
needDistinctInfo = false).aggInfos needDistinctInfo = false).aggInfos
...@@ -223,39 +222,39 @@ object AggregateUtil extends Enumeration { ...@@ -223,39 +222,39 @@ object AggregateUtil extends Enumeration {
} }
def transformToBatchAggregateInfoList( def transformToBatchAggregateInfoList(
aggregateCalls: Seq[AggregateCall], inputRowType: RowType,
inputRowType: RelDataType, aggCalls: Seq[AggregateCall],
orderKeyIdx: Array[Int] = null, aggCallNeedRetractions: Array[Boolean] = null,
needRetractions: Array[Boolean] = null): AggregateInfoList = { orderKeyIndexes: Array[Int] = null): AggregateInfoList = {
val needRetractionArray = if (needRetractions == null) { val finalAggCallNeedRetractions = if (aggCallNeedRetractions == null) {
Array.fill(aggregateCalls.size)(false) Array.fill(aggCalls.size)(false)
} else { } else {
needRetractions aggCallNeedRetractions
} }
transformToAggregateInfoList( transformToAggregateInfoList(
aggregateCalls,
inputRowType, inputRowType,
orderKeyIdx, aggCalls,
needRetractionArray, finalAggCallNeedRetractions,
orderKeyIndexes,
needInputCount = false, needInputCount = false,
isStateBackedDataViews = false, isStateBackedDataViews = false,
needDistinctInfo = false) needDistinctInfo = false)
} }
def transformToStreamAggregateInfoList( def transformToStreamAggregateInfoList(
inputRowType: RowType,
aggregateCalls: Seq[AggregateCall], aggregateCalls: Seq[AggregateCall],
inputRowType: RelDataType, aggCallNeedRetractions: Array[Boolean],
needRetraction: Array[Boolean],
needInputCount: Boolean, needInputCount: Boolean,
isStateBackendDataViews: Boolean, isStateBackendDataViews: Boolean,
needDistinctInfo: Boolean = true): AggregateInfoList = { needDistinctInfo: Boolean = true): AggregateInfoList = {
transformToAggregateInfoList( transformToAggregateInfoList(
aggregateCalls,
inputRowType, inputRowType,
orderKeyIdx = null, aggregateCalls,
needRetraction ++ Array(needInputCount), // for additional count(*) aggCallNeedRetractions ++ Array(needInputCount), // for additional count(*)
orderKeyIndexes = null,
needInputCount, needInputCount,
isStateBackendDataViews, isStateBackendDataViews,
needDistinctInfo) needDistinctInfo)
...@@ -264,10 +263,10 @@ object AggregateUtil extends Enumeration { ...@@ -264,10 +263,10 @@ object AggregateUtil extends Enumeration {
/** /**
* Transforms calcite aggregate calls to AggregateInfos. * Transforms calcite aggregate calls to AggregateInfos.
* *
* @param inputRowType the input's output RowType
* @param aggregateCalls the calcite aggregate calls * @param aggregateCalls the calcite aggregate calls
* @param inputRowType the input rel data type * @param aggCallNeedRetractions whether the aggregate function need retract method
* @param orderKeyIdx the index of order by field in the input, null if not over agg * @param orderKeyIndexes the index of order by field in the input, null if not over agg
* @param needRetraction whether the aggregate function need retract method
* @param needInputCount whether need to calculate the input counts, which is used in * @param needInputCount whether need to calculate the input counts, which is used in
* aggregation with retraction input.If needed, * aggregation with retraction input.If needed,
* insert a count(1) aggregate into the agg list. * insert a count(1) aggregate into the agg list.
...@@ -275,10 +274,10 @@ object AggregateUtil extends Enumeration { ...@@ -275,10 +274,10 @@ object AggregateUtil extends Enumeration {
* @param needDistinctInfo whether need to extract distinct information * @param needDistinctInfo whether need to extract distinct information
*/ */
private def transformToAggregateInfoList( private def transformToAggregateInfoList(
inputRowType: RowType,
aggregateCalls: Seq[AggregateCall], aggregateCalls: Seq[AggregateCall],
inputRowType: RelDataType, aggCallNeedRetractions: Array[Boolean],
orderKeyIdx: Array[Int], orderKeyIndexes: Array[Int],
needRetraction: Array[Boolean],
needInputCount: Boolean, needInputCount: Boolean,
isStateBackedDataViews: Boolean, isStateBackedDataViews: Boolean,
needDistinctInfo: Boolean): AggregateInfoList = { needDistinctInfo: Boolean): AggregateInfoList = {
...@@ -301,12 +300,12 @@ object AggregateUtil extends Enumeration { ...@@ -301,12 +300,12 @@ object AggregateUtil extends Enumeration {
// Step-3: // Step-3:
// create aggregate information // create aggregate information
val factory = new AggFunctionFactory(inputRowType, orderKeyIdx, needRetraction) val factory = new AggFunctionFactory(inputRowType, orderKeyIndexes, aggCallNeedRetractions)
val aggInfos = newAggCalls val aggInfos = newAggCalls
.zipWithIndex .zipWithIndex
.map { case (call, index) => .map { case (call, index) =>
val argIndexes = call.getAggregation match { val argIndexes = call.getAggregation match {
case _: SqlRankFunction => orderKeyIdx case _: SqlRankFunction => orderKeyIndexes
case _ => call.getArgList.map(_.intValue()).toArray case _ => call.getArgList.map(_.intValue()).toArray
} }
transformToAggregateInfo( transformToAggregateInfo(
...@@ -316,14 +315,14 @@ object AggregateUtil extends Enumeration { ...@@ -316,14 +315,14 @@ object AggregateUtil extends Enumeration {
argIndexes, argIndexes,
factory.createAggFunction(call, index), factory.createAggFunction(call, index),
isStateBackedDataViews, isStateBackedDataViews,
needRetraction(index)) aggCallNeedRetractions(index))
} }
AggregateInfoList(aggInfos.toArray, indexOfCountStar, countStarInserted, distinctInfos) AggregateInfoList(aggInfos.toArray, indexOfCountStar, countStarInserted, distinctInfos)
} }
private def transformToAggregateInfo( private def transformToAggregateInfo(
inputRowRelDataType: RelDataType, inputRowType: RowType,
call: AggregateCall, call: AggregateCall,
index: Int, index: Int,
argIndexes: Array[Int], argIndexes: Array[Int],
...@@ -334,7 +333,7 @@ object AggregateUtil extends Enumeration { ...@@ -334,7 +333,7 @@ object AggregateUtil extends Enumeration {
case _: BridgingSqlAggFunction => case _: BridgingSqlAggFunction =>
createAggregateInfoFromBridgingFunction( createAggregateInfoFromBridgingFunction(
inputRowRelDataType, inputRowType,
call, call,
index, index,
argIndexes, argIndexes,
...@@ -344,7 +343,7 @@ object AggregateUtil extends Enumeration { ...@@ -344,7 +343,7 @@ object AggregateUtil extends Enumeration {
case _: AggSqlFunction => case _: AggSqlFunction =>
createAggregateInfoFromLegacyFunction( createAggregateInfoFromLegacyFunction(
inputRowRelDataType, inputRowType,
call, call,
index, index,
argIndexes, argIndexes,
...@@ -363,7 +362,7 @@ object AggregateUtil extends Enumeration { ...@@ -363,7 +362,7 @@ object AggregateUtil extends Enumeration {
} }
private def createAggregateInfoFromBridgingFunction( private def createAggregateInfoFromBridgingFunction(
inputRowRelDataType: RelDataType, inputRowType: RowType,
call: AggregateCall, call: AggregateCall,
index: Int, index: Int,
argIndexes: Array[Int], argIndexes: Array[Int],
...@@ -387,7 +386,7 @@ object AggregateUtil extends Enumeration { ...@@ -387,7 +386,7 @@ object AggregateUtil extends Enumeration {
function.getTypeFactory, function.getTypeFactory,
function, function,
SqlTypeUtil.projectTypes( SqlTypeUtil.projectTypes(
inputRowRelDataType, FlinkTypeFactory.INSTANCE.buildRelNodeRowType(inputRowType),
argIndexes.map(Int.box).toList), argIndexes.map(Int.box).toList),
0, 0,
false)) false))
...@@ -490,7 +489,7 @@ object AggregateUtil extends Enumeration { ...@@ -490,7 +489,7 @@ object AggregateUtil extends Enumeration {
} }
private def createAggregateInfoFromLegacyFunction( private def createAggregateInfoFromLegacyFunction(
inputRowRelDataType: RelDataType, inputRowType: RowType,
call: AggregateCall, call: AggregateCall,
index: Int, index: Int,
argIndexes: Array[Int], argIndexes: Array[Int],
...@@ -507,8 +506,7 @@ object AggregateUtil extends Enumeration { ...@@ -507,8 +506,7 @@ object AggregateUtil extends Enumeration {
} }
val externalAccType = getAccumulatorTypeOfAggregateFunction(a, implicitAccType) val externalAccType = getAccumulatorTypeOfAggregateFunction(a, implicitAccType)
val argTypes = call.getArgList val argTypes = call.getArgList
.map(idx => inputRowRelDataType.getFieldList.get(idx).getType) .map(idx => inputRowType.getChildren.get(idx))
.map(FlinkTypeFactory.toLogicalType)
val externalArgTypes: Array[DataType] = getAggUserDefinedInputTypes( val externalArgTypes: Array[DataType] = getAggUserDefinedInputTypes(
a, a,
externalAccType, externalAccType,
...@@ -605,7 +603,7 @@ object AggregateUtil extends Enumeration { ...@@ -605,7 +603,7 @@ object AggregateUtil extends Enumeration {
private def extractDistinctInformation( private def extractDistinctInformation(
needDistinctInfo: Boolean, needDistinctInfo: Boolean,
aggCalls: Seq[AggregateCall], aggCalls: Seq[AggregateCall],
inputType: RelDataType, inputType: RowType,
hasStateBackedDataViews: Boolean, hasStateBackedDataViews: Boolean,
consumeRetraction: Boolean): (Array[DistinctInfo], Seq[AggregateCall]) = { consumeRetraction: Boolean): (Array[DistinctInfo], Seq[AggregateCall]) = {
...@@ -621,8 +619,7 @@ object AggregateUtil extends Enumeration { ...@@ -621,8 +619,7 @@ object AggregateUtil extends Enumeration {
if (call.isDistinct && !call.isApproximate && argIndexes.length > 0) { if (call.isDistinct && !call.isApproximate && argIndexes.length > 0) {
val argTypes: Array[LogicalType] = call val argTypes: Array[LogicalType] = call
.getArgList .getArgList
.map(inputType.getFieldList.get(_).getType) .map(inputType.getChildren.get(_))
.map(FlinkTypeFactory.toLogicalType)
.toArray .toArray
val keyType = createDistinctKeyType(argTypes) val keyType = createDistinctKeyType(argTypes)
...@@ -790,9 +787,9 @@ object AggregateUtil extends Enumeration { ...@@ -790,9 +787,9 @@ object AggregateUtil extends Enumeration {
*/ */
def getNeedRetractions( def getNeedRetractions(
groupCount: Int, groupCount: Int,
aggCalls: Seq[AggregateCall],
needRetraction: Boolean, needRetraction: Boolean,
monotonicity: RelModifiedMonotonicity, monotonicity: RelModifiedMonotonicity): Array[Boolean] = {
aggCalls: Seq[AggregateCall]): Array[Boolean] = {
val needRetractionArray = Array.fill(aggCalls.size)(needRetraction) val needRetractionArray = Array.fill(aggCalls.size)(needRetraction)
if (monotonicity != null && needRetraction) { if (monotonicity != null && needRetraction) {
aggCalls.zipWithIndex.foreach { case (aggCall, idx) => aggCalls.zipWithIndex.foreach { case (aggCall, idx) =>
......
...@@ -952,7 +952,9 @@ class FlinkRelMdHandlerTestBase { ...@@ -952,7 +952,9 @@ class FlinkRelMdHandlerTestBase {
val aggCalls = logicalAgg.getAggCallList val aggCalls = logicalAgg.getAggCallList
val aggFunctionFactory = new AggFunctionFactory( val aggFunctionFactory = new AggFunctionFactory(
studentBatchScan.getRowType, Array.empty[Int], Array.fill(aggCalls.size())(false)) FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
Array.empty[Int],
Array.fill(aggCalls.size())(false))
val aggCallToAggFunction = aggCalls.zipWithIndex.map { val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
} }
...@@ -1018,11 +1020,11 @@ class FlinkRelMdHandlerTestBase { ...@@ -1018,11 +1020,11 @@ class FlinkRelMdHandlerTestBase {
isMerge = false) isMerge = false)
val needRetractionArray = AggregateUtil.getNeedRetractions( val needRetractionArray = AggregateUtil.getNeedRetractions(
1, needRetraction = false, null, aggCalls) 1, aggCalls, needRetraction = false, null)
val localAggInfoList = transformToStreamAggregateInfoList( val localAggInfoList = transformToStreamAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(studentStreamScan.getRowType),
aggCalls, aggCalls,
studentStreamScan.getRowType,
needRetractionArray, needRetractionArray,
needInputCount = false, needInputCount = false,
isStateBackendDataViews = false) isStateBackendDataViews = false)
...@@ -1039,8 +1041,8 @@ class FlinkRelMdHandlerTestBase { ...@@ -1039,8 +1041,8 @@ class FlinkRelMdHandlerTestBase {
val streamExchange1 = new StreamPhysicalExchange( val streamExchange1 = new StreamPhysicalExchange(
cluster, streamLocalAgg.getTraitSet.replace(hash0), streamLocalAgg, hash0) cluster, streamLocalAgg.getTraitSet.replace(hash0), streamLocalAgg, hash0)
val globalAggInfoList = transformToStreamAggregateInfoList( val globalAggInfoList = transformToStreamAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(streamExchange1.getRowType),
aggCalls, aggCalls,
streamExchange1.getRowType,
needRetractionArray, needRetractionArray,
needInputCount = false, needInputCount = false,
isStateBackendDataViews = true) isStateBackendDataViews = true)
...@@ -1103,7 +1105,9 @@ class FlinkRelMdHandlerTestBase { ...@@ -1103,7 +1105,9 @@ class FlinkRelMdHandlerTestBase {
call => call.getAggregation != FlinkSqlOperatorTable.AUXILIARY_GROUP call => call.getAggregation != FlinkSqlOperatorTable.AUXILIARY_GROUP
} }
val aggFunctionFactory = new AggFunctionFactory( val aggFunctionFactory = new AggFunctionFactory(
studentBatchScan.getRowType, Array.empty[Int], Array.fill(aggCalls.size())(false)) FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
Array.empty[Int],
Array.fill(aggCalls.size())(false))
val aggCallToAggFunction = aggCalls.zipWithIndex.map { val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
} }
...@@ -1245,7 +1249,8 @@ class FlinkRelMdHandlerTestBase { ...@@ -1245,7 +1249,8 @@ class FlinkRelMdHandlerTestBase {
cluster, batchPhysicalTraits.replace(hash01), batchCalc, hash01) cluster, batchPhysicalTraits.replace(hash01), batchCalc, hash01)
val (_, _, aggregates) = val (_, _, aggregates) =
AggregateUtil.transformToBatchAggregateFunctions( AggregateUtil.transformToBatchAggregateFunctions(
flinkLogicalWindowAgg.getAggCallList, batchExchange1.getRowType) FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType),
flinkLogicalWindowAgg.getAggCallList)
val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates) val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates)
val localWindowAggTypes = val localWindowAggTypes =
...@@ -1390,7 +1395,8 @@ class FlinkRelMdHandlerTestBase { ...@@ -1390,7 +1395,8 @@ class FlinkRelMdHandlerTestBase {
cluster, batchPhysicalTraits.replace(hash1), batchCalc, hash1) cluster, batchPhysicalTraits.replace(hash1), batchCalc, hash1)
val (_, _, aggregates) = val (_, _, aggregates) =
AggregateUtil.transformToBatchAggregateFunctions( AggregateUtil.transformToBatchAggregateFunctions(
flinkLogicalWindowAgg.getAggCallList, batchExchange1.getRowType) FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType),
flinkLogicalWindowAgg.getAggCallList)
val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates) val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates)
val localWindowAggTypes = val localWindowAggTypes =
...@@ -1538,7 +1544,8 @@ class FlinkRelMdHandlerTestBase { ...@@ -1538,7 +1544,8 @@ class FlinkRelMdHandlerTestBase {
val aggCallsWithoutAuxGroup = flinkLogicalWindowAggWithAuxGroup.getAggCallList.drop(1) val aggCallsWithoutAuxGroup = flinkLogicalWindowAggWithAuxGroup.getAggCallList.drop(1)
val (_, _, aggregates) = val (_, _, aggregates) =
AggregateUtil.transformToBatchAggregateFunctions( AggregateUtil.transformToBatchAggregateFunctions(
aggCallsWithoutAuxGroup, batchExchange1.getRowType) FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType),
aggCallsWithoutAuxGroup)
val aggCallToAggFunction = aggCallsWithoutAuxGroup.zip(aggregates) val aggCallToAggFunction = aggCallsWithoutAuxGroup.zip(aggregates)
val localWindowAggTypes = val localWindowAggTypes =
...@@ -2438,7 +2445,9 @@ class FlinkRelMdHandlerTestBase { ...@@ -2438,7 +2445,9 @@ class FlinkRelMdHandlerTestBase {
).build().asInstanceOf[LogicalAggregate] ).build().asInstanceOf[LogicalAggregate]
val aggCalls = logicalAgg.getAggCallList val aggCalls = logicalAgg.getAggCallList
val aggFunctionFactory = new AggFunctionFactory( val aggFunctionFactory = new AggFunctionFactory(
studentBatchScan.getRowType, Array.empty[Int], Array.fill(aggCalls.size())(false)) FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
Array.empty[Int],
Array.fill(aggCalls.size())(false))
val aggCallToAggFunction = aggCalls.zipWithIndex.map { val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册