Commit 53a888c5 authored by godfreyhe

[FLINK-20737][table-planner-blink] Use RowType instead of RelDataType when building aggregate info

This closes #14478
Parent 5351203f
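Every hunk below applies the same migration: the aggregate utilities now accept Flink's logical `RowType` instead of Calcite's `RelDataType`, so each call site converts once via `FlinkTypeFactory.toLogicalRowType` and passes the row type as the first argument. A minimal before/after sketch, assuming hypothetical `inputRel: RelNode` and `aggCalls: Seq[AggregateCall]` values (not lines from this diff):

```scala
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.AggregateCall
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.utils.AggregateUtil
import org.apache.flink.table.types.logical.RowType

def buildAggInfos(inputRel: RelNode, aggCalls: Seq[AggregateCall]) = {
  // Before this commit (sketch):
  //   AggregateUtil.transformToBatchAggregateInfoList(aggCalls, inputRel.getRowType)
  // After: convert RelDataType -> RowType once, and pass the RowType first.
  val inputRowType: RowType = FlinkTypeFactory.toLogicalRowType(inputRel.getRowType)
  AggregateUtil.transformToBatchAggregateInfoList(inputRowType, aggCalls)
}
```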
......@@ -21,6 +21,7 @@ package org.apache.flink.table.planner.plan.rules.physical.batch;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.functions.python.PythonFunctionKind;
+ import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupAggregate;
......@@ -106,7 +107,9 @@ public class BatchExecPythonAggregateRule extends ConverterRule {
Tuple3<int[][], DataType[][], UserDefinedFunction[]> aggBufferTypesAndFunctions =
AggregateUtil.transformToBatchAggregateFunctions(
- aggCallsWithoutAuxGroupCalls, input.getRowType(), null);
+ FlinkTypeFactory.toLogicalRowType(input.getRowType()),
+ aggCallsWithoutAuxGroupCalls,
+ null);
UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3();
RelTraitSet requiredTraitSet =
......
......@@ -22,6 +22,7 @@ import org.apache.flink.table.api.TableException;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.calcite.FlinkRelFactories;
+ import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.logical.LogicalWindow;
import org.apache.flink.table.planner.plan.logical.SessionGroupWindow;
import org.apache.flink.table.planner.plan.logical.SlidingGroupWindow;
......@@ -121,7 +122,9 @@ public class BatchExecPythonWindowAggregateRule extends RelOptRule {
Tuple3<int[][], DataType[][], UserDefinedFunction[]> aggBufferTypesAndFunctions =
AggregateUtil.transformToBatchAggregateFunctions(
- aggCallsWithoutAuxGroupCalls, input.getRowType(), null);
+ FlinkTypeFactory.toLogicalRowType(input.getRowType()),
+ aggCallsWithoutAuxGroupCalls,
+ null);
UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3();
int inputTimeFieldIndex =
......
......@@ -682,8 +682,8 @@ class MatchCodeGenerator(
matchAgg.inputExprs.indices.map(i => s"TMP$i"))
val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
+ FlinkTypeFactory.toLogicalRowType(inputRelType),
aggCalls,
- inputRelType,
needRetraction,
needInputCount = false,
isStateBackendDataViews = false,
......
......@@ -114,7 +114,7 @@ abstract class BatchExecHashAggregateBase(
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList(
- aggCallToAggFunction.map(_._1), aggInputRowType)
+ FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1))
var managedMemory: Long = 0L
val generatedOperator = if (grouping.isEmpty) {
......
......@@ -113,7 +113,7 @@ abstract class BatchExecHashWindowAggregateBase(
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList(
- aggCallToAggFunction.map(_._1), aggInputRowType)
+ FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1))
val groupBufferLimitSize = config.getConfiguration.getInteger(
ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT)
......
......@@ -140,10 +140,10 @@ class BatchExecOverAggregate(
//operator needn't cache data
val aggHandlers = modeToGroupToAggCallToAggFunction.map { case (_, _, aggCallToAggFunction) =>
val aggInfoList = transformToBatchAggregateInfoList(
- aggCallToAggFunction.map(_._1),
// use aggInputType which considers constants as input instead of inputType
- inputTypeWithConstants,
- orderKeyIndices)
+ FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants),
+ aggCallToAggFunction.map(_._1),
+ orderKeyIndexes = orderKeyIndices)
val codeGenCtx = CodeGeneratorContext(config)
val generator = new AggsHandlerCodeGenerator(
codeGenCtx,
......@@ -191,10 +191,10 @@ class BatchExecOverAggregate(
//lies on the offset of the window frame.
aggCallToAggFunction.map { case (aggCall, _) =>
val aggInfoList = transformToBatchAggregateInfoList(
+ FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants),
Seq(aggCall),
- inputTypeWithConstants,
- orderKeyIndices,
- Array[Boolean](true) /* needRetraction = true, See LeadLagAggFunction */)
+ Array[Boolean](true), /* needRetraction = true, See LeadLagAggFunction */
+ orderKeyIndexes = orderKeyIndices)
val generator = new AggsHandlerCodeGenerator(
CodeGeneratorContext(config),
......@@ -263,10 +263,10 @@ class BatchExecOverAggregate(
case _ =>
val aggInfoList = transformToBatchAggregateInfoList(
- aggCallToAggFunction.map(_._1),
//use aggInputType which considers constants as input instead of inputSchema.relDataType
- inputTypeWithConstants,
- orderKeyIndices)
+ FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants),
+ aggCallToAggFunction.map(_._1),
+ orderKeyIndexes = orderKeyIndices)
val codeGenCtx = CodeGeneratorContext(config)
val generator = new AggsHandlerCodeGenerator(
codeGenCtx,
......
......@@ -95,7 +95,7 @@ abstract class BatchExecSortAggregateBase(
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList(
- aggCallToAggFunction.map(_._1), aggInputRowType)
+ FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1))
val generatedOperator = if (grouping.isEmpty) {
AggWithoutKeysCodeGenerator.genWithoutKeys(
......
......@@ -101,7 +101,7 @@ abstract class BatchExecSortWindowAggregateBase(
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList(
- aggCallToAggFunction.map(_._1), aggInputRowType)
+ FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1))
val groupBufferLimitSize = planner.getTableConfig.getConfiguration.getInteger(
ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT)
......
......@@ -63,8 +63,8 @@ class StreamExecGroupAggregate(
val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList(
this,
- aggCalls,
- grouping)
+ grouping.length,
+ aggCalls)
override def requireWatermark: Boolean = false
......
......@@ -42,8 +42,8 @@ abstract class StreamExecGroupTableAggregateBase(
val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList(
this,
- aggCalls,
- grouping)
+ grouping.length,
+ aggCalls)
override def requireWatermark: Boolean = false
......
......@@ -137,8 +137,8 @@ abstract class StreamExecGroupWindowAggregateBase(
val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this)
val aggInfoList = transformToStreamAggregateInfoList(
+ FlinkTypeFactory.toLogicalRowType(inputRowType),
aggCalls,
- inputRowType,
Array.fill(aggCalls.size)(needRetraction),
needInputCount = needRetraction,
isStateBackendDataViews = true)
......
......@@ -243,9 +243,9 @@ class StreamExecOverAggregate(
val needRetraction = false
val aggInfoList = transformToStreamAggregateInfoList(
- aggregateCalls,
// use aggInputType which considers constants as input instead of inputSchema.relDataType
- aggInputType,
+ FlinkTypeFactory.toLogicalRowType(aggInputType),
+ aggregateCalls,
Array.fill(aggregateCalls.size)(needRetraction),
needInputCount = needRetraction,
isStateBackendDataViews = true)
......@@ -322,9 +322,9 @@ class StreamExecOverAggregate(
val needRetraction = true
val aggInfoList = transformToStreamAggregateInfoList(
- aggregateCalls,
// use aggInputType which considers constants as input instead of inputSchema.relDataType
- aggInputType,
+ FlinkTypeFactory.toLogicalRowType(aggInputType),
+ aggregateCalls,
Array.fill(aggregateCalls.size)(needRetraction),
needInputCount = needRetraction,
isStateBackendDataViews = true)
......
......@@ -59,8 +59,8 @@ class StreamExecPythonGroupAggregate(
val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList(
this,
- aggCalls,
- grouping)
+ grouping.length,
+ aggCalls)
override def requireWatermark: Boolean = false
......
......@@ -154,7 +154,7 @@ trait BatchExecAggRuleBase {
protected def isAggBufferFixedLength(agg: Aggregate): Boolean = {
val (_, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
val (_, aggBufferTypes, _) = AggregateUtil.transformToBatchAggregateFunctions(
- aggCallsWithoutAuxGroupCalls, agg.getInput.getRowType)
+ FlinkTypeFactory.toLogicalRowType(agg.getInput.getRowType), aggCallsWithoutAuxGroupCalls)
isAggBufferFixedLength(aggBufferTypes.map(_.map(fromDataTypeToLogicalType)))
}
......
......@@ -18,7 +18,7 @@
package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.api.config.OptimizerConfigOptions
- import org.apache.flink.table.planner.calcite.FlinkContext
+ import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory}
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate
......@@ -87,7 +87,7 @@ class BatchExecHashAggRule
val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
val (_, aggBufferTypes, aggFunctions) = AggregateUtil.transformToBatchAggregateFunctions(
- aggCallsWithoutAuxGroupCalls, inputRowType)
+ FlinkTypeFactory.toLogicalRowType(inputRowType), aggCallsWithoutAuxGroupCalls)
val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggFunctions)
val aggProvidedTraitSet = agg.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)
......
......@@ -94,7 +94,9 @@ class BatchExecOverAggregateRule
val groupToAggCallToAggFunction = groupBuffer.map { group =>
val aggregateCalls = group.getAggregateCalls(logicWindow)
val (_, _, aggregates) = AggregateUtil.transformToBatchAggregateFunctions(
- aggregateCalls, inputTypeWithConstants, orderKeyIndexes)
+ FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants),
+ aggregateCalls,
+ orderKeyIndexes)
val aggCallToAggFunction = aggregateCalls.zip(aggregates)
(group, aggCallToAggFunction)
}
......
......@@ -18,7 +18,7 @@
package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.api.config.OptimizerConfigOptions
- import org.apache.flink.table.planner.calcite.FlinkContext
+ import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory}
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate
......@@ -80,7 +80,7 @@ class BatchExecSortAggRule
val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
val (_, aggBufferTypes, aggFunctions) = AggregateUtil.transformToBatchAggregateFunctions(
- aggCallsWithoutAuxGroupCalls, inputRowType)
+ FlinkTypeFactory.toLogicalRowType(inputRowType), aggCallsWithoutAuxGroupCalls)
val groupSet = agg.getGroupSet.toArray
val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggFunctions)
// TODO aggregate include projection now, so do not provide new trait will be safe
......
......@@ -101,7 +101,7 @@ class BatchExecWindowAggregateRule
val (auxGroupSet, aggCallsWithoutAuxGroupCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
val (_, aggBufferTypes, aggregates) = AggregateUtil.transformToBatchAggregateFunctions(
- aggCallsWithoutAuxGroupCalls, input.getRowType)
+ FlinkTypeFactory.toLogicalRowType(input.getRowType), aggCallsWithoutAuxGroupCalls)
val aggCallToAggFunction = aggCallsWithoutAuxGroupCalls.zip(aggregates)
val internalAggBufferTypes = aggBufferTypes.map(_.map(fromDataTypeToLogicalType))
val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig
......
......@@ -19,9 +19,10 @@
package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.api.TableException
+ import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
- import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalExpand, BatchExecGroupAggregateBase, BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange}
+ import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand}
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleOperand}
......@@ -75,7 +76,7 @@ abstract class EnforceLocalAggRuleBase(
val aggCallToAggFunction = completeAgg.getAggCallToAggFunction
val (_, aggBufferTypes, _) = AggregateUtil.transformToBatchAggregateFunctions(
- aggCalls, inputRowType)
+ FlinkTypeFactory.toLogicalRowType(inputRowType), aggCalls)
val traitSet = cluster.getPlanner
.emptyTraitSet
......
......@@ -132,9 +132,9 @@ class IncrementalAggregateRule
} else {
// an additional count1 is inserted, need to adapt the global agg
val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
- aggCalls,
// the final agg input is partial agg
- partialGlobalAgg.getRowType,
+ FlinkTypeFactory.toLogicalRowType(partialGlobalAgg.getRowType),
+ aggCalls,
// all the aggs do not need retraction
Array.fill(aggCalls.length)(false),
// also do not need count*
......@@ -142,9 +142,9 @@ class IncrementalAggregateRule
// the local agg is not works on state
isStateBackendDataViews = false)
val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
- aggCalls,
// the final agg input is partial agg
- partialGlobalAgg.getRowType,
+ FlinkTypeFactory.toLogicalRowType(partialGlobalAgg.getRowType),
+ aggCalls,
// all the aggs do not need retraction
Array.fill(aggCalls.length)(false),
// also do not need count*
......
......@@ -66,11 +66,11 @@ class TwoStageOptimizedAggregateRule extends RelOptRule(
val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery)
val monotonicity = fmq.getRelModifiedMonotonicity(agg)
val needRetractionArray = AggregateUtil.getNeedRetractions(
- agg.grouping.length, needRetraction, monotonicity, agg.aggCalls)
+ agg.grouping.length, agg.aggCalls, needRetraction, monotonicity)
val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
+ FlinkTypeFactory.toLogicalRowType( agg.getInput.getRowType),
agg.aggCalls,
- agg.getInput.getRowType,
needRetractionArray,
needRetraction,
isStateBackendDataViews = true)
......@@ -98,18 +98,18 @@ class TwoStageOptimizedAggregateRule extends RelOptRule(
val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery)
val monotonicity = fmq.getRelModifiedMonotonicity(agg)
val needRetractionArray = AggregateUtil.getNeedRetractions(
- agg.grouping.length, needRetraction, monotonicity, agg.aggCalls)
+ agg.grouping.length, agg.aggCalls, needRetraction, monotonicity)
val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
+ FlinkTypeFactory.toLogicalRowType(realInput.getRowType),
agg.aggCalls,
- realInput.getRowType,
needRetractionArray,
needRetraction,
isStateBackendDataViews = false)
val globalAggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
+ FlinkTypeFactory.toLogicalRowType(realInput.getRowType),
agg.aggCalls,
- realInput.getRowType,
needRetractionArray,
needRetraction,
isStateBackendDataViews = true)
......
......@@ -19,24 +19,17 @@ package org.apache.flink.table.planner.plan.utils
import org.apache.flink.table.api.TableException
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
- import org.apache.flink.table.planner.functions.aggfunctions.FirstValueAggFunction._
- import org.apache.flink.table.planner.functions.aggfunctions.FirstValueWithRetractAggFunction._
- import org.apache.flink.table.planner.functions.aggfunctions.IncrSumAggFunction._
- import org.apache.flink.table.planner.functions.aggfunctions.IncrSumWithRetractAggFunction._
- import org.apache.flink.table.planner.functions.aggfunctions.LastValueAggFunction._
- import org.apache.flink.table.planner.functions.aggfunctions.LastValueWithRetractAggFunction._
- import org.apache.flink.table.planner.functions.aggfunctions.SingleValueAggFunction._
- import org.apache.flink.table.planner.functions.aggfunctions.SumWithRetractAggFunction._
+ import org.apache.flink.table.planner.functions.aggfunctions._
import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction
import org.apache.flink.table.planner.functions.sql.{SqlFirstLastValueAggFunction, SqlListAggFunction}
import org.apache.flink.table.planner.functions.utils.AggSqlFunction
import org.apache.flink.table.runtime.typeutils.DecimalDataTypeInfo
import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.table.types.logical._
- import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.sql.fun._
import org.apache.calcite.sql.{SqlAggFunction, SqlKind, SqlRankFunction}
......@@ -49,14 +42,14 @@ import scala.collection.JavaConversions._
* The class of agg function factory which is used to create AggregateFunction or
* DeclarativeAggregateFunction from Calcite AggregateCall
*
- * @param inputType the input rel data type
- * @param orderKeyIdx the indexes of order key (null when is not over agg)
- * @param needRetraction true if need retraction
+ * @param inputRowType the input's output RowType
+ * @param orderKeyIndexes the indexes of order key (null when is not over agg)
+ * @param aggCallNeedRetractions true if need retraction
*/
class AggFunctionFactory(
- inputType: RelDataType,
- orderKeyIdx: Array[Int],
- needRetraction: Array[Boolean]) {
+ inputRowType: RowType,
+ orderKeyIndexes: Array[Int],
+ aggCallNeedRetractions: Array[Boolean]) {
/**
* The entry point to create an aggregate function from the given AggregateCall
......@@ -64,8 +57,7 @@ class AggFunctionFactory(
def createAggFunction(call: AggregateCall, index: Int): UserDefinedFunction = {
val argTypes: Array[LogicalType] = call.getArgList
- .map(inputType.getFieldList.get(_).getType)
- .map(FlinkTypeFactory.toLogicalType)
+ .map(inputRowType.getChildren.get(_))
.toArray
call.getAggregation match {
......@@ -165,7 +157,7 @@ class AggFunctionFactory(
private def createSumAggFunction(
argTypes: Array[LogicalType],
index: Int): UserDefinedFunction = {
- if (needRetraction(index)) {
+ if (aggCallNeedRetractions(index)) {
argTypes(0).getTypeRoot match {
case TINYINT =>
new ByteSumWithRetractAggFunction
......@@ -236,7 +228,7 @@ class AggFunctionFactory(
private def createIncrSumAggFunction(
argTypes: Array[LogicalType],
index: Int): UserDefinedFunction = {
- if (needRetraction(index)) {
+ if (aggCallNeedRetractions(index)) {
argTypes(0).getTypeRoot match {
case TINYINT =>
new ByteIncrSumWithRetractAggFunction
......@@ -286,7 +278,7 @@ class AggFunctionFactory(
index: Int)
: UserDefinedFunction = {
val valueType = argTypes(0)
- if (needRetraction(index)) {
+ if (aggCallNeedRetractions(index)) {
valueType.getTypeRoot match {
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL |
TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE =>
......@@ -370,7 +362,7 @@ class AggFunctionFactory(
index: Int)
: UserDefinedFunction = {
val valueType = argTypes(0)
- if (needRetraction(index)) {
+ if (aggCallNeedRetractions(index)) {
valueType.getTypeRoot match {
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL |
TIME_WITHOUT_TIME_ZONE | DATE | TIMESTAMP_WITHOUT_TIME_ZONE =>
......@@ -460,16 +452,12 @@ class AggFunctionFactory(
}
private def createRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
- val argTypes = orderKeyIdx
- .map(inputType.getFieldList.get(_).getType)
- .map(FlinkTypeFactory.toLogicalType)
+ val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
new RankAggFunction(argTypes)
}
private def createDenseRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
- val argTypes = orderKeyIdx
- .map(inputType.getFieldList.get(_).getType)
- .map(FlinkTypeFactory.toLogicalType)
+ val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
new DenseRankAggFunction(argTypes)
}
......@@ -478,7 +466,7 @@ class AggFunctionFactory(
index: Int)
: UserDefinedFunction = {
val valueType = argTypes(0)
- if (needRetraction(index)) {
+ if (aggCallNeedRetractions(index)) {
valueType.getTypeRoot match {
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL =>
new FirstValueWithRetractAggFunction(valueType)
......@@ -502,7 +490,7 @@ class AggFunctionFactory(
index: Int)
: UserDefinedFunction = {
val valueType = argTypes(0)
- if (needRetraction(index)) {
+ if (aggCallNeedRetractions(index)) {
valueType.getTypeRoot match {
case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | BOOLEAN | VARCHAR | DECIMAL =>
new LastValueWithRetractAggFunction(valueType)
......@@ -524,7 +512,7 @@ class AggFunctionFactory(
private def createListAggFunction(
argTypes: Array[LogicalType],
index: Int): UserDefinedFunction = {
- if (needRetraction(index)) {
+ if (aggCallNeedRetractions(index)) {
new ListAggWithRetractAggFunction
} else {
new ListAggFunction(1)
......@@ -534,7 +522,7 @@ class AggFunctionFactory(
private def createListAggWsFunction(
argTypes: Array[LogicalType],
index: Int): UserDefinedFunction = {
- if (needRetraction(index)) {
+ if (aggCallNeedRetractions(index)) {
new ListAggWsWithRetractAggFunction
} else {
new ListAggFunction(2)
......
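With a `RowType` in hand, `AggFunctionFactory` now reads argument types straight from the row's children instead of mapping `RelDataTypeField`s through `FlinkTypeFactory.toLogicalType`. A hedged sketch of constructing the factory under the new signature, mirroring the test code later in this diff (`inputRel` and `aggCalls` are assumed values):

```scala
val factory = new AggFunctionFactory(
  FlinkTypeFactory.toLogicalRowType(inputRel.getRowType), // inputRowType: RowType
  Array.empty[Int],                                       // orderKeyIndexes: empty, not an over agg
  Array.fill(aggCalls.size)(false))                       // aggCallNeedRetractions: no retraction
val firstAggFunction = factory.createAggFunction(aggCalls.head, 0)
```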
......@@ -151,12 +151,12 @@ object AggregateUtil extends Enumeration {
def getOutputIndexToAggCallIndexMap(
aggregateCalls: Seq[AggregateCall],
inputType: RelDataType,
- orderKeyIdx: Array[Int] = null): util.Map[Integer, Integer] = {
+ orderKeyIndexes: Array[Int] = null): util.Map[Integer, Integer] = {
val aggInfos = transformToAggregateInfoList(
+ FlinkTypeFactory.toLogicalRowType(inputType),
aggregateCalls,
- inputType,
- orderKeyIdx,
Array.fill(aggregateCalls.size)(false),
+ orderKeyIndexes,
needInputCount = false,
isStateBackedDataViews = false,
needDistinctInfo = false).aggInfos
......@@ -176,10 +176,10 @@ object AggregateUtil extends Enumeration {
}
def deriveAggregateInfoList(
- aggNode: StreamPhysicalRel,
- aggCalls: Seq[AggregateCall],
- grouping: Array[Int]): AggregateInfoList = {
- val input = aggNode.getInput(0)
+ agg: StreamPhysicalRel,
+ groupCount: Int,
+ aggCalls: Seq[AggregateCall]): AggregateInfoList = {
+ val input = agg.getInput(0)
// need to call `retract()` if input contains update or delete
val modifyKindSetTrait = input.getTraitSet.getTrait(ModifyKindSetTraitDef.INSTANCE)
val needRetraction = if (modifyKindSetTrait == null) {
......@@ -188,29 +188,28 @@ object AggregateUtil extends Enumeration {
} else {
!modifyKindSetTrait.modifyKindSet.isInsertOnly
}
- val fmq = FlinkRelMetadataQuery.reuseOrCreate(aggNode.getCluster.getMetadataQuery)
- val monotonicity = fmq.getRelModifiedMonotonicity(aggNode)
- val needRetractionArray = AggregateUtil.getNeedRetractions(
- grouping.length, needRetraction, monotonicity, aggCalls)
- AggregateUtil.transformToStreamAggregateInfoList(
+ val fmq = FlinkRelMetadataQuery.reuseOrCreate(agg.getCluster.getMetadataQuery)
+ val monotonicity = fmq.getRelModifiedMonotonicity(agg)
+ val needRetractionArray = getNeedRetractions(groupCount, aggCalls, needRetraction, monotonicity)
+ transformToStreamAggregateInfoList(
+ FlinkTypeFactory.toLogicalRowType(input.getRowType),
aggCalls,
- input.getRowType,
needRetractionArray,
needInputCount = needRetraction,
isStateBackendDataViews = true)
}
def transformToBatchAggregateFunctions(
+ inputRowType: RowType,
aggregateCalls: Seq[AggregateCall],
- inputRowType: RelDataType,
- orderKeyIdx: Array[Int] = null)
+ orderKeyIndexes: Array[Int] = null)
: (Array[Array[Int]], Array[Array[DataType]], Array[UserDefinedFunction]) = {
val aggInfos = transformToAggregateInfoList(
- aggregateCalls,
inputRowType,
- orderKeyIdx,
+ aggregateCalls,
Array.fill(aggregateCalls.size)(false),
+ orderKeyIndexes,
needInputCount = false,
isStateBackedDataViews = false,
needDistinctInfo = false).aggInfos
......@@ -223,39 +222,39 @@ object AggregateUtil extends Enumeration {
}
def transformToBatchAggregateInfoList(
- aggregateCalls: Seq[AggregateCall],
- inputRowType: RelDataType,
- orderKeyIdx: Array[Int] = null,
- needRetractions: Array[Boolean] = null): AggregateInfoList = {
+ inputRowType: RowType,
+ aggCalls: Seq[AggregateCall],
+ aggCallNeedRetractions: Array[Boolean] = null,
+ orderKeyIndexes: Array[Int] = null): AggregateInfoList = {
- val needRetractionArray = if (needRetractions == null) {
- Array.fill(aggregateCalls.size)(false)
+ val finalAggCallNeedRetractions = if (aggCallNeedRetractions == null) {
+ Array.fill(aggCalls.size)(false)
} else {
- needRetractions
+ aggCallNeedRetractions
}
transformToAggregateInfoList(
- aggregateCalls,
inputRowType,
- orderKeyIdx,
- needRetractionArray,
+ aggCalls,
+ finalAggCallNeedRetractions,
+ orderKeyIndexes,
needInputCount = false,
isStateBackedDataViews = false,
needDistinctInfo = false)
}
def transformToStreamAggregateInfoList(
+ inputRowType: RowType,
aggregateCalls: Seq[AggregateCall],
- inputRowType: RelDataType,
- needRetraction: Array[Boolean],
+ aggCallNeedRetractions: Array[Boolean],
needInputCount: Boolean,
isStateBackendDataViews: Boolean,
needDistinctInfo: Boolean = true): AggregateInfoList = {
transformToAggregateInfoList(
- aggregateCalls,
inputRowType,
- orderKeyIdx = null,
- needRetraction ++ Array(needInputCount), // for additional count(*)
+ aggregateCalls,
+ aggCallNeedRetractions ++ Array(needInputCount), // for additional count(*)
+ orderKeyIndexes = null,
needInputCount,
isStateBackendDataViews,
needDistinctInfo)
......@@ -264,10 +263,10 @@ object AggregateUtil extends Enumeration {
/**
* Transforms calcite aggregate calls to AggregateInfos.
*
+ * @param inputRowType the input's output RowType
* @param aggregateCalls the calcite aggregate calls
- * @param inputRowType the input rel data type
- * @param orderKeyIdx the index of order by field in the input, null if not over agg
- * @param needRetraction whether the aggregate function need retract method
+ * @param aggCallNeedRetractions whether the aggregate function need retract method
+ * @param orderKeyIndexes the index of order by field in the input, null if not over agg
* @param needInputCount whether need to calculate the input counts, which is used in
* aggregation with retraction input.If needed,
* insert a count(1) aggregate into the agg list.
......@@ -275,10 +274,10 @@ object AggregateUtil extends Enumeration {
* @param needDistinctInfo whether need to extract distinct information
*/
private def transformToAggregateInfoList(
+ inputRowType: RowType,
aggregateCalls: Seq[AggregateCall],
- inputRowType: RelDataType,
- orderKeyIdx: Array[Int],
- needRetraction: Array[Boolean],
+ aggCallNeedRetractions: Array[Boolean],
+ orderKeyIndexes: Array[Int],
needInputCount: Boolean,
isStateBackedDataViews: Boolean,
needDistinctInfo: Boolean): AggregateInfoList = {
......@@ -301,12 +300,12 @@ object AggregateUtil extends Enumeration {
// Step-3:
// create aggregate information
- val factory = new AggFunctionFactory(inputRowType, orderKeyIdx, needRetraction)
+ val factory = new AggFunctionFactory(inputRowType, orderKeyIndexes, aggCallNeedRetractions)
val aggInfos = newAggCalls
.zipWithIndex
.map { case (call, index) =>
val argIndexes = call.getAggregation match {
- case _: SqlRankFunction => orderKeyIdx
+ case _: SqlRankFunction => orderKeyIndexes
case _ => call.getArgList.map(_.intValue()).toArray
}
transformToAggregateInfo(
......@@ -316,14 +315,14 @@ object AggregateUtil extends Enumeration {
argIndexes,
factory.createAggFunction(call, index),
isStateBackedDataViews,
- needRetraction(index))
+ aggCallNeedRetractions(index))
}
AggregateInfoList(aggInfos.toArray, indexOfCountStar, countStarInserted, distinctInfos)
}
private def transformToAggregateInfo(
- inputRowRelDataType: RelDataType,
+ inputRowType: RowType,
call: AggregateCall,
index: Int,
argIndexes: Array[Int],
......@@ -334,7 +333,7 @@ object AggregateUtil extends Enumeration {
case _: BridgingSqlAggFunction =>
createAggregateInfoFromBridgingFunction(
- inputRowRelDataType,
+ inputRowType,
call,
index,
argIndexes,
......@@ -344,7 +343,7 @@ object AggregateUtil extends Enumeration {
case _: AggSqlFunction =>
createAggregateInfoFromLegacyFunction(
- inputRowRelDataType,
+ inputRowType,
call,
index,
argIndexes,
......@@ -363,7 +362,7 @@ object AggregateUtil extends Enumeration {
}
private def createAggregateInfoFromBridgingFunction(
- inputRowRelDataType: RelDataType,
+ inputRowType: RowType,
call: AggregateCall,
index: Int,
argIndexes: Array[Int],
......@@ -387,7 +386,7 @@ object AggregateUtil extends Enumeration {
function.getTypeFactory,
function,
SqlTypeUtil.projectTypes(
- inputRowRelDataType,
+ FlinkTypeFactory.INSTANCE.buildRelNodeRowType(inputRowType),
argIndexes.map(Int.box).toList),
0,
false))
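Note the asymmetry in the bridging-function hunk above: Calcite's `SqlTypeUtil.projectTypes` still requires a `RelDataType`, so the `RowType` is converted back on demand with `FlinkTypeFactory.INSTANCE.buildRelNodeRowType`. Both directions of the conversion appear in this commit; a sketch with an assumed `rowType` value:

```scala
// RowType -> RelDataType, for legacy Calcite utilities that still need it
val relType = FlinkTypeFactory.INSTANCE.buildRelNodeRowType(rowType)
// RelDataType -> RowType, the direction used by most call sites in this diff
val roundTripped = FlinkTypeFactory.toLogicalRowType(relType)
```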
......@@ -490,7 +489,7 @@ object AggregateUtil extends Enumeration {
}
private def createAggregateInfoFromLegacyFunction(
- inputRowRelDataType: RelDataType,
+ inputRowType: RowType,
call: AggregateCall,
index: Int,
argIndexes: Array[Int],
......@@ -507,8 +506,7 @@ object AggregateUtil extends Enumeration {
}
val externalAccType = getAccumulatorTypeOfAggregateFunction(a, implicitAccType)
val argTypes = call.getArgList
- .map(idx => inputRowRelDataType.getFieldList.get(idx).getType)
- .map(FlinkTypeFactory.toLogicalType)
+ .map(idx => inputRowType.getChildren.get(idx))
val externalArgTypes: Array[DataType] = getAggUserDefinedInputTypes(
a,
externalAccType,
......@@ -605,7 +603,7 @@ object AggregateUtil extends Enumeration {
private def extractDistinctInformation(
needDistinctInfo: Boolean,
aggCalls: Seq[AggregateCall],
- inputType: RelDataType,
+ inputType: RowType,
hasStateBackedDataViews: Boolean,
consumeRetraction: Boolean): (Array[DistinctInfo], Seq[AggregateCall]) = {
......@@ -621,8 +619,7 @@ object AggregateUtil extends Enumeration {
if (call.isDistinct && !call.isApproximate && argIndexes.length > 0) {
val argTypes: Array[LogicalType] = call
.getArgList
- .map(inputType.getFieldList.get(_).getType)
- .map(FlinkTypeFactory.toLogicalType)
+ .map(inputType.getChildren.get(_))
.toArray
val keyType = createDistinctKeyType(argTypes)
......@@ -790,9 +787,9 @@ object AggregateUtil extends Enumeration {
*/
def getNeedRetractions(
groupCount: Int,
+ aggCalls: Seq[AggregateCall],
needRetraction: Boolean,
- monotonicity: RelModifiedMonotonicity,
- aggCalls: Seq[AggregateCall]): Array[Boolean] = {
+ monotonicity: RelModifiedMonotonicity): Array[Boolean] = {
val needRetractionArray = Array.fill(aggCalls.size)(needRetraction)
if (monotonicity != null && needRetraction) {
aggCalls.zipWithIndex.foreach { case (aggCall, idx) =>
......
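The reordered `getNeedRetractions` composes with the stream variant as below; a sketch assuming a stream aggregate `agg` exposing `grouping` and `aggCalls`, with `needRetraction` and `monotonicity` derived as in the rules above:

```scala
val needRetractionArray = AggregateUtil.getNeedRetractions(
  agg.grouping.length, agg.aggCalls, needRetraction, monotonicity)
val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
  FlinkTypeFactory.toLogicalRowType(agg.getInput.getRowType),
  agg.aggCalls,
  needRetractionArray,
  needInputCount = needRetraction,
  isStateBackendDataViews = true)
```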
......@@ -952,7 +952,9 @@ class FlinkRelMdHandlerTestBase {
val aggCalls = logicalAgg.getAggCallList
val aggFunctionFactory = new AggFunctionFactory(
- studentBatchScan.getRowType, Array.empty[Int], Array.fill(aggCalls.size())(false))
+ FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
+ Array.empty[Int],
+ Array.fill(aggCalls.size())(false))
val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
}
......@@ -1018,11 +1020,11 @@ class FlinkRelMdHandlerTestBase {
isMerge = false)
val needRetractionArray = AggregateUtil.getNeedRetractions(
- 1, needRetraction = false, null, aggCalls)
+ 1, aggCalls, needRetraction = false, null)
val localAggInfoList = transformToStreamAggregateInfoList(
+ FlinkTypeFactory.toLogicalRowType(studentStreamScan.getRowType),
aggCalls,
- studentStreamScan.getRowType,
needRetractionArray,
needInputCount = false,
isStateBackendDataViews = false)
......@@ -1039,8 +1041,8 @@ class FlinkRelMdHandlerTestBase {
val streamExchange1 = new StreamPhysicalExchange(
cluster, streamLocalAgg.getTraitSet.replace(hash0), streamLocalAgg, hash0)
val globalAggInfoList = transformToStreamAggregateInfoList(
+ FlinkTypeFactory.toLogicalRowType(streamExchange1.getRowType),
aggCalls,
- streamExchange1.getRowType,
needRetractionArray,
needInputCount = false,
isStateBackendDataViews = true)
......@@ -1103,7 +1105,9 @@ class FlinkRelMdHandlerTestBase {
call => call.getAggregation != FlinkSqlOperatorTable.AUXILIARY_GROUP
}
val aggFunctionFactory = new AggFunctionFactory(
- studentBatchScan.getRowType, Array.empty[Int], Array.fill(aggCalls.size())(false))
+ FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
+ Array.empty[Int],
+ Array.fill(aggCalls.size())(false))
val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
}
......@@ -1245,7 +1249,8 @@ class FlinkRelMdHandlerTestBase {
cluster, batchPhysicalTraits.replace(hash01), batchCalc, hash01)
val (_, _, aggregates) =
AggregateUtil.transformToBatchAggregateFunctions(
- flinkLogicalWindowAgg.getAggCallList, batchExchange1.getRowType)
+ FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType),
+ flinkLogicalWindowAgg.getAggCallList)
val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates)
val localWindowAggTypes =
......@@ -1390,7 +1395,8 @@ class FlinkRelMdHandlerTestBase {
cluster, batchPhysicalTraits.replace(hash1), batchCalc, hash1)
val (_, _, aggregates) =
AggregateUtil.transformToBatchAggregateFunctions(
- flinkLogicalWindowAgg.getAggCallList, batchExchange1.getRowType)
+ FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType),
+ flinkLogicalWindowAgg.getAggCallList)
val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates)
val localWindowAggTypes =
......@@ -1538,7 +1544,8 @@ class FlinkRelMdHandlerTestBase {
val aggCallsWithoutAuxGroup = flinkLogicalWindowAggWithAuxGroup.getAggCallList.drop(1)
val (_, _, aggregates) =
AggregateUtil.transformToBatchAggregateFunctions(
- aggCallsWithoutAuxGroup, batchExchange1.getRowType)
+ FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType),
+ aggCallsWithoutAuxGroup)
val aggCallToAggFunction = aggCallsWithoutAuxGroup.zip(aggregates)
val localWindowAggTypes =
......@@ -2438,7 +2445,9 @@ class FlinkRelMdHandlerTestBase {
).build().asInstanceOf[LogicalAggregate]
val aggCalls = logicalAgg.getAggCallList
val aggFunctionFactory = new AggFunctionFactory(
- studentBatchScan.getRowType, Array.empty[Int], Array.fill(aggCalls.size())(false))
+ FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
+ Array.empty[Int],
+ Array.fill(aggCalls.size())(false))
val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
}
......