Commit c96d1fff authored by godfreyhe

[FLINK-20737][table-planner-blink] Introduce StreamPhysicalGroupTableAggregate, and make StreamExecGroupTableAggregate only extended from ExecNode

This closes #14478
Parent 1eaf54b7
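
For context before the file listings: this commit continues the planner's split between optimizer-facing physical RelNodes and translation-only ExecNodes. The following is a minimal, self-contained sketch of that layering; every name in it is hypothetical and simplified, not Flink API.

    // Hypothetical names throughout -- this only illustrates the two-layer pattern.
    trait SketchExecNode[T] {
      // translation to the runtime plan lives here, free of Calcite concerns
      def translateToPlan(): T
    }

    class SketchExecGroupTableAggregate(grouping: Array[Int]) extends SketchExecNode[String] {
      // all runtime wiring (operators, key selectors, state types) would live here
      override def translateToPlan(): String =
        s"GroupTableAggregate(grouping=[${grouping.mkString(",")}])"
    }

    trait SketchPhysicalRel {
      def translateToExecNode(): SketchExecNode[_]
    }

    class SketchPhysicalGroupTableAggregate(grouping: Array[Int]) extends SketchPhysicalRel {
      // the optimizer-facing node only derives properties and builds the exec node
      override def translateToExecNode(): SketchExecNode[_] =
        new SketchExecGroupTableAggregate(grouping)
    }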
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.planner.plan.nodes.exec.stream;

import org.apache.flink.api.dag.Transformation;
import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
import org.apache.flink.table.planner.plan.utils.KeySelectorUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.runtime.generated.GeneratedTableAggsHandleFunction;
import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.operators.aggregate.GroupTableAggFunction;
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Preconditions;

import org.apache.calcite.rel.core.AggregateCall;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Collections;

/** Stream {@link ExecNode} for unbounded java/scala group table aggregate. */
public class StreamExecGroupTableAggregate extends ExecNodeBase<RowData>
        implements StreamExecNode<RowData> {

    private static final Logger LOG = LoggerFactory.getLogger(StreamExecGroupTableAggregate.class);

    private final int[] grouping;
    private final AggregateCall[] aggCalls;

    /** Each element indicates whether the corresponding agg call needs `retract` method. */
    private final boolean[] aggCallNeedRetractions;

    /** Whether this node will generate UPDATE_BEFORE messages. */
    private final boolean generateUpdateBefore;

    /** Whether this node consumes retraction messages. */
    private final boolean needRetraction;

    public StreamExecGroupTableAggregate(
            int[] grouping,
            AggregateCall[] aggCalls,
            boolean[] aggCallNeedRetractions,
            boolean generateUpdateBefore,
            boolean needRetraction,
            ExecEdge inputEdge,
            RowType outputType,
            String description) {
        super(Collections.singletonList(inputEdge), outputType, description);
        Preconditions.checkArgument(aggCalls.length == aggCallNeedRetractions.length);
        this.grouping = grouping;
        this.aggCalls = aggCalls;
        this.aggCallNeedRetractions = aggCallNeedRetractions;
        this.generateUpdateBefore = generateUpdateBefore;
        this.needRetraction = needRetraction;
    }

    @SuppressWarnings("unchecked")
    @Override
    protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
        final TableConfig tableConfig = planner.getTableConfig();
        if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime() < 0) {
            LOG.warn(
                    "No state retention interval configured for a query which accumulates state. "
                            + "Please provide a query configuration with valid retention interval to prevent excessive "
                            + "state size. You may specify a retention time of 0 to not clean up the state.");
        }

        final ExecNode<RowData> inputNode = (ExecNode<RowData>) getInputNodes().get(0);
        final Transformation<RowData> inputTransform = inputNode.translateToPlan(planner);
        final RowType inputRowType = (RowType) inputNode.getOutputType();

        final AggsHandlerCodeGenerator generator =
                new AggsHandlerCodeGenerator(
                                new CodeGeneratorContext(tableConfig),
                                planner.getRelBuilder(),
                                JavaScalaConversionUtil.toScala(inputRowType.getChildren()),
                                // TODO: heap state backend do not copy key currently,
                                // we have to copy input field
                                // TODO: copy is not need when state backend is rocksdb,
                                // improve this in future
                                // TODO: but other operators do not copy this input field.....
                                true)
                        .needAccumulate();
        if (needRetraction) {
            generator.needRetract();
        }

        final AggregateInfoList aggInfoList =
                AggregateUtil.transformToStreamAggregateInfoList(
                        inputRowType,
                        JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                        aggCallNeedRetractions,
                        needRetraction,
                        true,
                        true);
        final GeneratedTableAggsHandleFunction aggsHandler =
                generator.generateTableAggsHandler("GroupTableAggHandler", aggInfoList);

        final LogicalType[] accTypes =
                Arrays.stream(aggInfoList.getAccTypes())
                        .map(LogicalTypeDataTypeConverter::fromDataTypeToLogicalType)
                        .toArray(LogicalType[]::new);
        final int inputCountIndex = aggInfoList.getIndexOfCountStar();
        final GroupTableAggFunction aggFunction =
                new GroupTableAggFunction(
                        aggsHandler,
                        accTypes,
                        inputCountIndex,
                        generateUpdateBefore,
                        tableConfig.getIdleStateRetention().toMillis());
        final OneInputStreamOperator<RowData, RowData> operator =
                new KeyedProcessOperator<>(aggFunction);

        // partitioned aggregation
        final OneInputTransformation<RowData, RowData> transform =
                new OneInputTransformation<>(
                        inputTransform,
                        "GroupTableAggregate",
                        operator,
                        InternalTypeInfo.of(getOutputType()),
                        inputTransform.getParallelism());

        if (inputsContainSingleton()) {
            transform.setParallelism(1);
            transform.setMaxParallelism(1);
        }

        // set KeyType and Selector for state
        final RowDataKeySelector selector =
                KeySelectorUtil.getRowDataSelector(grouping, InternalTypeInfo.of(inputRowType));
        transform.setStateKeySelector(selector);
        transform.setStateKeyType(selector.getProducedType());
        return transform;
    }
}
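
The ExecNode above is exercised by unbounded group table aggregates. Below is a minimal sketch of the query shape that reaches it, modeled on the Top2 example from the Flink documentation and assuming a Flink 1.12-era Table API environment; Top2, Top2Accumulator, and the sample rows are illustrative and not part of this commit.

    import org.apache.flink.table.api._
    import org.apache.flink.table.functions.TableAggregateFunction
    import org.apache.flink.util.Collector

    // Assumed user-defined table aggregate: emitValue may emit several rows per
    // group, which is what distinguishes a table aggregate from a regular one.
    class Top2Accumulator {
      var first: Int = Int.MinValue
      var second: Int = Int.MinValue
    }

    class Top2 extends TableAggregateFunction[Int, Top2Accumulator] {
      override def createAccumulator(): Top2Accumulator = new Top2Accumulator

      def accumulate(acc: Top2Accumulator, value: Int): Unit = {
        if (value > acc.first) {
          acc.second = acc.first
          acc.first = value
        } else if (value > acc.second) {
          acc.second = value
        }
      }

      def emitValue(acc: Top2Accumulator, out: Collector[Int]): Unit = {
        out.collect(acc.first)
        if (acc.second != Int.MinValue) {
          out.collect(acc.second)
        }
      }
    }

    object Top2Example {
      def main(args: Array[String]): Unit = {
        val settings = EnvironmentSettings.newInstance().inStreamingMode().build()
        val tEnv = TableEnvironment.create(settings)
        tEnv.createTemporarySystemFunction("Top2", classOf[Top2])

        val input = tEnv.fromValues(row("a", 1), row("a", 2), row("b", 3)).as("key", "v")

        // groupBy + flatAggregate is the shape the planner maps to
        // StreamPhysicalGroupTableAggregate and, at translation time,
        // to the StreamExecGroupTableAggregate above.
        val result = input
          .groupBy($"key")
          .flatAggregate(call("Top2", $"v").as("top"))
          .select($"key", $"top")

        result.execute().print()
      }
    }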
@@ -446,7 +446,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
    * @return interval of the given column on stream group TableAggregate
    */
   def getColumnInterval(
-      aggregate: StreamExecGroupTableAggregate,
+      aggregate: StreamPhysicalGroupTableAggregate,
       mq: RelMetadataQuery,
       index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

@@ -550,7 +550,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
         agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
       case agg: BatchExecWindowAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
       case agg: TableAggregate => agg.getGroupSet.toArray
-      case agg: StreamExecGroupTableAggregate => agg.grouping
+      case agg: StreamPhysicalGroupTableAggregate => agg.grouping
       case agg: StreamExecGroupWindowTableAggregate => agg.getGrouping
     }
...
@@ -20,7 +20,7 @@ package org.apache.flink.table.planner.plan.metadata
 import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.FilteredColumnInterval
 import org.apache.flink.table.planner.plan.nodes.calcite.TableAggregate
 import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecGroupAggregateBase
-import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGroupTableAggregate, StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamPhysicalGlobalGroupAggregate, StreamPhysicalGroupAggregate, StreamPhysicalLocalGroupAggregate}
+import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamPhysicalGlobalGroupAggregate, StreamPhysicalGroupAggregate, StreamPhysicalGroupTableAggregate, StreamPhysicalLocalGroupAggregate}
 import org.apache.flink.table.planner.plan.stats.ValueInterval
 import org.apache.flink.table.planner.plan.utils.ColumnIntervalUtil
 import org.apache.flink.util.Preconditions.checkArgument

@@ -192,7 +192,7 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredColumnInterval] {
   }

   def getFilteredColumnInterval(
-      aggregate: StreamExecGroupTableAggregate,
+      aggregate: StreamPhysicalGroupTableAggregate,
       mq: RelMetadataQuery,
       columnIndex: Int,
       filterArg: Int): ValueInterval = {
...
@@ -282,7 +282,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMonotonicity] {
   }

   def getRelModifiedMonotonicity(
-      rel: StreamExecGroupTableAggregate,
+      rel: StreamPhysicalGroupTableAggregate,
       mq: RelMetadataQuery): RelModifiedMonotonicity = {
     getRelModifiedMonotonicityOnTableAggregate(
       rel.getInput, rel.grouping, rel.getRowType.getFieldCount, mq)
...
@@ -26,6 +26,7 @@ import org.apache.flink.table.data.RowData
 import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory
 import org.apache.flink.table.planner.delegation.StreamPlanner
+import org.apache.flink.table.planner.plan.nodes.exec.LegacyStreamExecNode
 import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate
 import org.apache.flink.table.planner.plan.utils.{ChangelogPlanUtils, KeySelectorUtil}
 import org.apache.flink.table.planner.typeutils.DataViewUtils.DataViewSpec

@@ -49,13 +50,14 @@ class StreamExecPythonGroupTableAggregate(
     outputRowType: RelDataType,
     grouping: Array[Int],
     aggCalls: Seq[AggregateCall])
-  extends StreamExecGroupTableAggregateBase(
+  extends StreamPhysicalGroupTableAggregateBase(
     cluster,
     traitSet,
     inputRel,
     outputRowType,
     grouping,
     aggCalls)
+  with LegacyStreamExecNode[RowData]
   with CommonExecPythonAggregate {

   override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
...
@@ -17,39 +17,29 @@
  */
 package org.apache.flink.table.planner.plan.nodes.physical.stream

-import org.apache.flink.api.dag.Transformation
-import org.apache.flink.streaming.api.operators.KeyedProcessOperator
-import org.apache.flink.streaming.api.transformations.OneInputTransformation
-import org.apache.flink.table.data.RowData
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory
-import org.apache.flink.table.planner.codegen.CodeGeneratorContext
-import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator
-import org.apache.flink.table.planner.delegation.StreamPlanner
-import org.apache.flink.table.planner.plan.utils._
-import org.apache.flink.table.runtime.operators.aggregate.GroupTableAggFunction
-import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
-import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
+import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecGroupTableAggregate
+import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
+import org.apache.flink.table.planner.plan.utils.{AggregateUtil, ChangelogPlanUtils}

 import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
-import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rel.RelNode

 import java.util

-import scala.collection.JavaConversions._
-
 /**
  * Stream physical RelNode for unbounded java/scala group table aggregate.
  */
-class StreamExecGroupTableAggregate(
+class StreamPhysicalGroupTableAggregate(
     cluster: RelOptCluster,
     traitSet: RelTraitSet,
     inputRel: RelNode,
     outputRowType: RelDataType,
     grouping: Array[Int],
     aggCalls: Seq[AggregateCall])
-  extends StreamExecGroupTableAggregateBase(
+  extends StreamPhysicalGroupTableAggregateBase(
     cluster,
     traitSet,
     inputRel,

@@ -58,7 +48,7 @@ class StreamExecGroupTableAggregate(
     aggCalls) {

   override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
-    new StreamExecGroupTableAggregate(
+    new StreamPhysicalGroupTableAggregate(
       cluster,
       traitSet,
       inputs.get(0),

@@ -67,74 +57,20 @@ class StreamExecGroupTableAggregate(
       aggCalls)
   }

-  override protected def translateToPlanInternal(
-      planner: StreamPlanner): Transformation[RowData] = {
-
-    val tableConfig = planner.getTableConfig
-    if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime < 0) {
-      LOG.warn("No state retention interval configured for a query which accumulates state. " +
-        "Please provide a query configuration with valid retention interval to prevent excessive " +
-        "state size. You may specify a retention time of 0 to not clean up the state.")
-    }
-
-    val inputTransformation = getInputNodes.get(0).translateToPlan(planner)
-      .asInstanceOf[Transformation[RowData]]
-
-    val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType)
-    val inputRowType = FlinkTypeFactory.toLogicalRowType(getInput.getRowType)
-
+  override def translateToExecNode(): ExecNode[_] = {
+    val aggCallNeedRetractions =
+      AggregateUtil.deriveAggCallNeedRetractions(this, grouping.length, aggCalls)
     val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this)
     val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this)
-
-    val generator = new AggsHandlerCodeGenerator(
-      CodeGeneratorContext(tableConfig),
-      planner.getRelBuilder,
-      inputRowType.getChildren,
-      // TODO: heap state backend do not copy key currently, we have to copy input field
-      // TODO: copy is not need when state backend is rocksdb, improve this in future
-      // TODO: but other operators do not copy this input field.....
-      copyInputField = true)
-
-    if (needRetraction) {
-      generator.needRetract()
-    }
-
-    val aggsHandler = generator
-      .needAccumulate()
-      .generateTableAggsHandler("GroupTableAggHandler", aggInfoList)
-
-    val accTypes = aggInfoList.getAccTypes.map(fromDataTypeToLogicalType)
-    val inputCountIndex = aggInfoList.getIndexOfCountStar
-    val aggFunction = new GroupTableAggFunction(
-      aggsHandler,
-      accTypes,
-      inputCountIndex,
-      generateUpdateBefore,
-      tableConfig.getIdleStateRetention.toMillis)
-    val operator = new KeyedProcessOperator[RowData, RowData, RowData](aggFunction)
-
-    val selector = KeySelectorUtil.getRowDataSelector(
-      grouping,
-      InternalTypeInfo.of(inputRowType))
-
-    // partitioned aggregation
-    val ret = new OneInputTransformation(
-      inputTransformation,
-      "GroupTableAggregate",
-      operator,
-      InternalTypeInfo.of(outRowType),
-      inputTransformation.getParallelism)
-
-    if (inputsContainSingleton()) {
-      ret.setParallelism(1)
-      ret.setMaxParallelism(1)
-    }
-
-    // set KeyType and Selector for state
-    ret.setStateKeySelector(selector)
-    ret.setStateKeyType(selector.getProducedType)
-    ret
+    new StreamExecGroupTableAggregate(
+      grouping,
+      aggCalls.toArray,
+      aggCallNeedRetractions,
+      generateUpdateBefore,
+      needRetraction,
+      ExecEdge.DEFAULT,
+      FlinkTypeFactory.toLogicalRowType(getRowType),
+      getRelDetailedDescription
+    )
   }
 }
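Net effect of this file: the physical RelNode now ends at translateToExecNode(), and the operator construction removed above (code-generated aggs handler, GroupTableAggFunction, KeyedProcessOperator, key selector and state wiring) lives in the Java StreamExecGroupTableAggregate at the top of this commit, which receives the grouping, agg calls, retraction flags, input edge, and output type as plain constructor arguments.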
@@ -17,8 +17,6 @@
  */
 package org.apache.flink.table.planner.plan.nodes.physical.stream

-import org.apache.flink.table.data.RowData
-import org.apache.flink.table.planner.plan.nodes.exec.LegacyStreamExecNode
 import org.apache.flink.table.planner.plan.utils._

 import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}

@@ -29,7 +27,7 @@ import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
 /**
  * Base Stream physical RelNode for unbounded group table aggregate.
  */
-abstract class StreamExecGroupTableAggregateBase(
+abstract class StreamPhysicalGroupTableAggregateBase(
     cluster: RelOptCluster,
     traitSet: RelTraitSet,
     inputRel: RelNode,

@@ -37,10 +35,9 @@ abstract class StreamExecGroupTableAggregateBase(
     val grouping: Array[Int],
     val aggCalls: Seq[AggregateCall])
   extends SingleRel(cluster, traitSet, inputRel)
-  with StreamPhysicalRel
-  with LegacyStreamExecNode[RowData] {
+  with StreamPhysicalRel {

-  val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList(
+  protected val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList(
     this,
     grouping.length,
     aggCalls)
...
@@ -178,7 +178,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOptimizeContext] {
         val providedTrait = new ModifyKindSetTrait(builder.build())
         createNewNode(agg, children, providedTrait, requiredTrait, requester)

-      case tagg: StreamExecGroupTableAggregateBase =>
+      case tagg: StreamPhysicalGroupTableAggregateBase =>
         // table agg support all changes in input
         val children = visitChildren(tagg, ModifyKindSetTrait.ALL_CHANGES)
         // table aggregate will produce all changes, including deletions

@@ -461,7 +461,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOptimizeContext] {
         }
         visitSink(sink, sinkRequiredTraits)

-      case _: StreamPhysicalGroupAggregate | _: StreamExecGroupTableAggregate |
+      case _: StreamPhysicalGroupAggregate | _: StreamPhysicalGroupTableAggregate |
           _: StreamPhysicalLimit | _: StreamPhysicalPythonGroupAggregate |
           _: StreamExecPythonGroupTableAggregate =>
         // Aggregate, TableAggregate and Limit requires update_before if there are updates
...
@@ -421,7 +421,7 @@ object FlinkStreamRuleSets {
     StreamPhysicalExpandRule.INSTANCE,
     // group agg
     StreamPhysicalGroupAggregateRule.INSTANCE,
-    StreamExecGroupTableAggregateRule.INSTANCE,
+    StreamPhysicalGroupTableAggregateRule.INSTANCE,
     StreamPhysicalPythonGroupAggregateRule.INSTANCE,
     StreamExecPythonGroupTableAggregateRule.INSTANCE,
     // over agg
...
@@ -18,22 +18,26 @@
 package org.apache.flink.table.planner.plan.rules.physical.stream

-import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.convert.ConverterRule
 import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
 import org.apache.flink.table.planner.plan.nodes.FlinkConventions
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableAggregate
-import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecGroupTableAggregate
+import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalGroupTableAggregate
 import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate

+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+
 import scala.collection.JavaConversions._

-class StreamExecGroupTableAggregateRule extends ConverterRule(
+/**
+ * Rule to convert a [[FlinkLogicalTableAggregate]] into a [[StreamPhysicalGroupTableAggregate]].
+ */
+class StreamPhysicalGroupTableAggregateRule extends ConverterRule(
   classOf[FlinkLogicalTableAggregate],
   FlinkConventions.LOGICAL,
   FlinkConventions.STREAM_PHYSICAL,
-  "StreamExecGroupTableAggregateRule") {
+  "StreamPhysicalGroupTableAggregateRule") {

   override def matches(call: RelOptRuleCall): Boolean = {
     val agg: FlinkLogicalTableAggregate = call.rel(0)

@@ -53,7 +57,7 @@ class StreamExecGroupTableAggregateRule extends ConverterRule(
     val providedTraitSet = rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
     val newInput: RelNode = RelOptRule.convert(agg.getInput, requiredTraitSet)

-    new StreamExecGroupTableAggregate(
+    new StreamPhysicalGroupTableAggregate(
       rel.getCluster,
       providedTraitSet,
       newInput,

@@ -64,7 +68,6 @@ class StreamExecGroupTableAggregateRule extends ConverterRule(
   }
 }

-object StreamExecGroupTableAggregateRule {
-  val INSTANCE: StreamExecGroupTableAggregateRule = new StreamExecGroupTableAggregateRule()
+object StreamPhysicalGroupTableAggregateRule {
+  val INSTANCE: StreamPhysicalGroupTableAggregateRule = new StreamPhysicalGroupTableAggregateRule()
 }
@@ -835,7 +835,7 @@ class FlinkRelMdHandlerTestBase {
     builder.add("f1", new BasicSqlType(typeFactory.getTypeSystem, SqlTypeName.INTEGER))
     val relDataType = builder.build()

-    val streamExecTableAgg = new StreamExecGroupTableAggregate(
+    val streamTableAgg = new StreamPhysicalGroupTableAggregate(
       cluster,
       logicalTraits,
       studentLogicalScan,

@@ -844,7 +844,7 @@ class FlinkRelMdHandlerTestBase {
       Seq(tableAggCall)
     )

-    (logicalTableAgg, flinkLogicalTableAgg, streamExecTableAgg)
+    (logicalTableAgg, flinkLogicalTableAgg, streamTableAgg)
   }

   // equivalent Table API is
...