Commit c96d1fff authored by godfreyhe

[FLINK-20737][table-planner-blink] Introduce StreamPhysicalGroupTableAggregate, and make StreamExecGroupTableAggregate only extended from ExecNode

This closes #14478
Parent: 1eaf54b7
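At a glance, the change splits the old all-in-one physical node into an optimizer-side RelNode (StreamPhysicalGroupTableAggregate) and a translation-only ExecNode (StreamExecGroupTableAggregate, the new Java file below). A minimal, self-contained sketch of that shape follows; every type in it is a hypothetical simplification for illustration, not the planner's real API:

// Hedged sketch only: hypothetical, simplified types illustrating the
// RelNode/ExecNode split this commit introduces; not the real planner API.
public final class PhysicalExecSplitSketch {
    // Runtime side: knows how to translate itself, nothing about optimization.
    interface ExecNodeSketch {
        String translateToPlan();
    }
    // Mirrors the new StreamExecGroupTableAggregate: plain data plus translation.
    static final class ExecGroupTableAggSketch implements ExecNodeSketch {
        private final int[] grouping;
        ExecGroupTableAggSketch(int[] grouping) {
            this.grouping = grouping;
        }
        @Override
        public String translateToPlan() {
            return "GroupTableAggregate(groupingFieldCount=" + grouping.length + ")";
        }
    }
    // Mirrors the new StreamPhysicalGroupTableAggregate: an optimizer-side node
    // that hands off to the exec node instead of implementing execution itself.
    static final class PhysicalGroupTableAggSketch {
        private final int[] grouping;
        PhysicalGroupTableAggSketch(int[] grouping) {
            this.grouping = grouping;
        }
        ExecNodeSketch translateToExecNode() {
            return new ExecGroupTableAggSketch(grouping.clone());
        }
    }
    public static void main(String[] args) {
        ExecNodeSketch exec =
                new PhysicalGroupTableAggSketch(new int[]{0, 1}).translateToExecNode();
        System.out.println(exec.translateToPlan());
    }
}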
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.plan.nodes.exec.stream;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
import org.apache.flink.table.planner.plan.utils.KeySelectorUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.runtime.generated.GeneratedTableAggsHandleFunction;
import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.operators.aggregate.GroupTableAggFunction;
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Preconditions;
import org.apache.calcite.rel.core.AggregateCall;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Arrays;
import java.util.Collections;
/** Stream {@link ExecNode} for unbounded java/scala group table aggregate. */
public class StreamExecGroupTableAggregate extends ExecNodeBase<RowData>
implements StreamExecNode<RowData> {
private static final Logger LOG = LoggerFactory.getLogger(StreamExecGroupTableAggregate.class);
private final int[] grouping;
private final AggregateCall[] aggCalls;
/** Each element indicates whether the corresponding agg call needs `retract` method. */
private final boolean[] aggCallNeedRetractions;
/** Whether this node will generate UPDATE_BEFORE messages. */
private final boolean generateUpdateBefore;
/** Whether this node consumes retraction messages. */
private final boolean needRetraction;
public StreamExecGroupTableAggregate(
int[] grouping,
AggregateCall[] aggCalls,
boolean[] aggCallNeedRetractions,
boolean generateUpdateBefore,
boolean needRetraction,
ExecEdge inputEdge,
RowType outputType,
String description) {
super(Collections.singletonList(inputEdge), outputType, description);
Preconditions.checkArgument(aggCalls.length == aggCallNeedRetractions.length);
this.grouping = grouping;
this.aggCalls = aggCalls;
this.aggCallNeedRetractions = aggCallNeedRetractions;
this.generateUpdateBefore = generateUpdateBefore;
this.needRetraction = needRetraction;
}
@SuppressWarnings("unchecked")
@Override
protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
final TableConfig tableConfig = planner.getTableConfig();
if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime() < 0) {
LOG.warn(
"No state retention interval configured for a query which accumulates state. "
+ "Please provide a query configuration with valid retention interval to prevent excessive "
+ "state size. You may specify a retention time of 0 to not clean up the state.");
}
final ExecNode<RowData> inputNode = (ExecNode<RowData>) getInputNodes().get(0);
final Transformation<RowData> inputTransform = inputNode.translateToPlan(planner);
final RowType inputRowType = (RowType) inputNode.getOutputType();
final AggsHandlerCodeGenerator generator =
new AggsHandlerCodeGenerator(
new CodeGeneratorContext(tableConfig),
planner.getRelBuilder(),
JavaScalaConversionUtil.toScala(inputRowType.getChildren()),
// TODO: the heap state backend does not copy the key
// currently, so we have to copy the input field
// TODO: copying is not needed when the state backend is
// RocksDB; improve this in the future
// TODO: but other operators do not copy this input field...
true)
.needAccumulate();
if (needRetraction) {
generator.needRetract();
}
final AggregateInfoList aggInfoList =
AggregateUtil.transformToStreamAggregateInfoList(
inputRowType,
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
aggCallNeedRetractions,
needRetraction,
true,
true);
final GeneratedTableAggsHandleFunction aggsHandler =
generator.generateTableAggsHandler("GroupTableAggHandler", aggInfoList);
final LogicalType[] accTypes =
Arrays.stream(aggInfoList.getAccTypes())
.map(LogicalTypeDataTypeConverter::fromDataTypeToLogicalType)
.toArray(LogicalType[]::new);
final int inputCountIndex = aggInfoList.getIndexOfCountStar();
final GroupTableAggFunction aggFunction =
new GroupTableAggFunction(
aggsHandler,
accTypes,
inputCountIndex,
generateUpdateBefore,
tableConfig.getIdleStateRetention().toMillis());
final OneInputStreamOperator<RowData, RowData> operator =
new KeyedProcessOperator<>(aggFunction);
// partitioned aggregation
final OneInputTransformation<RowData, RowData> transform =
new OneInputTransformation<>(
inputTransform,
"GroupTableAggregate",
operator,
InternalTypeInfo.of(getOutputType()),
inputTransform.getParallelism());
if (inputsContainSingleton()) {
transform.setParallelism(1);
transform.setMaxParallelism(1);
}
// set KeyType and Selector for state
final RowDataKeySelector selector =
KeySelectorUtil.getRowDataSelector(grouping, InternalTypeInfo.of(inputRowType));
transform.setStateKeySelector(selector);
transform.setStateKeyType(selector.getProducedType());
return transform;
}
}
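A usage note on the retention warning above (not part of this commit): a job author can bound the state that a grouped table aggregate accumulates through the public TableConfig. A minimal sketch, assuming the Duration-based setter that pairs with the getIdleStateRetention() call used in translateToPlanInternal:

import java.time.Duration;
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class IdleStateRetentionExample {
    public static void main(String[] args) {
        TableEnvironment tEnv = TableEnvironment.create(
                EnvironmentSettings.newInstance().inStreamingMode().build());
        // Clear per-key aggregate state that has been idle for an hour.
        // A retention of 0 (the default) means "never clean up", which is
        // what triggers the planner warning above for accumulating queries.
        tEnv.getConfig().setIdleStateRetention(Duration.ofHours(1));
    }
}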
@@ -446,7 +446,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
* @return interval of the given column on stream group TableAggregate
*/
def getColumnInterval(
aggregate: StreamExecGroupTableAggregate,
aggregate: StreamPhysicalGroupTableAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)
@@ -550,7 +550,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
case agg: BatchExecWindowAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
case agg: TableAggregate => agg.getGroupSet.toArray
case agg: StreamExecGroupTableAggregate => agg.grouping
case agg: StreamPhysicalGroupTableAggregate => agg.grouping
case agg: StreamExecGroupWindowTableAggregate => agg.getGrouping
}
@@ -20,7 +20,7 @@ package org.apache.flink.table.planner.plan.metadata
import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.FilteredColumnInterval
import org.apache.flink.table.planner.plan.nodes.calcite.TableAggregate
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecGroupAggregateBase
import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGroupTableAggregate, StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamPhysicalGlobalGroupAggregate, StreamPhysicalGroupAggregate, StreamPhysicalLocalGroupAggregate}
import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamPhysicalGlobalGroupAggregate, StreamPhysicalGroupAggregate, StreamPhysicalGroupTableAggregate, StreamPhysicalLocalGroupAggregate}
import org.apache.flink.table.planner.plan.stats.ValueInterval
import org.apache.flink.table.planner.plan.utils.ColumnIntervalUtil
import org.apache.flink.util.Preconditions.checkArgument
@@ -192,7 +192,7 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredC
}
def getFilteredColumnInterval(
aggregate: StreamExecGroupTableAggregate,
aggregate: StreamPhysicalGroupTableAggregate,
mq: RelMetadataQuery,
columnIndex: Int,
filterArg: Int): ValueInterval = {
@@ -282,7 +282,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
}
def getRelModifiedMonotonicity(
rel: StreamExecGroupTableAggregate,
rel: StreamPhysicalGroupTableAggregate,
mq: RelMetadataQuery): RelModifiedMonotonicity = {
getRelModifiedMonotonicityOnTableAggregate(
rel.getInput, rel.grouping, rel.getRowType.getFieldCount, mq)
@@ -26,6 +26,7 @@ import org.apache.flink.table.data.RowData
import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.delegation.StreamPlanner
import org.apache.flink.table.planner.plan.nodes.exec.LegacyStreamExecNode
import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate
import org.apache.flink.table.planner.plan.utils.{ChangelogPlanUtils, KeySelectorUtil}
import org.apache.flink.table.planner.typeutils.DataViewUtils.DataViewSpec
@@ -49,13 +50,14 @@ class StreamExecPythonGroupTableAggregate(
outputRowType: RelDataType,
grouping: Array[Int],
aggCalls: Seq[AggregateCall])
extends StreamExecGroupTableAggregateBase(
extends StreamPhysicalGroupTableAggregateBase(
cluster,
traitSet,
inputRel,
outputRowType,
grouping,
aggCalls)
with LegacyStreamExecNode[RowData]
with CommonExecPythonAggregate {
override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
@@ -17,39 +17,29 @@
*/
package org.apache.flink.table.planner.plan.nodes.physical.stream
import org.apache.flink.api.dag.Transformation
import org.apache.flink.streaming.api.operators.KeyedProcessOperator
import org.apache.flink.streaming.api.transformations.OneInputTransformation
import org.apache.flink.table.data.RowData
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGeneratorContext
import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator
import org.apache.flink.table.planner.delegation.StreamPlanner
import org.apache.flink.table.planner.plan.utils._
import org.apache.flink.table.runtime.operators.aggregate.GroupTableAggFunction
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecGroupTableAggregate
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, ChangelogPlanUtils}
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.RelNode
import java.util
import scala.collection.JavaConversions._
/**
 * Stream physical RelNode for unbounded java/scala group table aggregate.
 */
class StreamExecGroupTableAggregate(
/**
 * Stream physical RelNode for unbounded java/scala group table aggregate.
 */
class StreamPhysicalGroupTableAggregate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
outputRowType: RelDataType,
grouping: Array[Int],
aggCalls: Seq[AggregateCall])
extends StreamExecGroupTableAggregateBase(
extends StreamPhysicalGroupTableAggregateBase(
cluster,
traitSet,
inputRel,
@@ -58,7 +48,7 @@ class StreamExecGroupTableAggregate(
aggCalls) {
override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new StreamExecGroupTableAggregate(
new StreamPhysicalGroupTableAggregate(
cluster,
traitSet,
inputs.get(0),
@@ -67,74 +57,20 @@ class StreamExecGroupTableAggregate(
aggCalls)
}
  override protected def translateToPlanInternal(
      planner: StreamPlanner): Transformation[RowData] = {
    val tableConfig = planner.getTableConfig
    if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime < 0) {
      LOG.warn("No state retention interval configured for a query which accumulates state. " +
        "Please provide a query configuration with valid retention interval to prevent excessive " +
        "state size. You may specify a retention time of 0 to not clean up the state.")
    }
    val inputTransformation = getInputNodes.get(0).translateToPlan(planner)
      .asInstanceOf[Transformation[RowData]]
    val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType)
    val inputRowType = FlinkTypeFactory.toLogicalRowType(getInput.getRowType)
    val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this)
    val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this)
    val generator = new AggsHandlerCodeGenerator(
      CodeGeneratorContext(tableConfig),
      planner.getRelBuilder,
      inputRowType.getChildren,
      // TODO: heap state backend do not copy key currently, we have to copy input field
      // TODO: copy is not need when state backend is rocksdb, improve this in future
      // TODO: but other operators do not copy this input field.....
      copyInputField = true)
    if (needRetraction) {
      generator.needRetract()
    }
    val aggsHandler = generator
      .needAccumulate()
      .generateTableAggsHandler("GroupTableAggHandler", aggInfoList)
    val accTypes = aggInfoList.getAccTypes.map(fromDataTypeToLogicalType)
    val inputCountIndex = aggInfoList.getIndexOfCountStar
    val aggFunction = new GroupTableAggFunction(
      aggsHandler,
      accTypes,
      inputCountIndex,
      generateUpdateBefore,
      tableConfig.getIdleStateRetention.toMillis)
    val operator = new KeyedProcessOperator[RowData, RowData, RowData](aggFunction)
    val selector = KeySelectorUtil.getRowDataSelector(
      grouping,
      InternalTypeInfo.of(inputRowType))
    // partitioned aggregation
    val ret = new OneInputTransformation(
      inputTransformation,
      "GroupTableAggregate",
      operator,
      InternalTypeInfo.of(outRowType),
      inputTransformation.getParallelism)
    if (inputsContainSingleton()) {
      ret.setParallelism(1)
      ret.setMaxParallelism(1)
    }
    // set KeyType and Selector for state
    ret.setStateKeySelector(selector)
    ret.setStateKeyType(selector.getProducedType)
    ret
  }
  override def translateToExecNode(): ExecNode[_] = {
    val aggCallNeedRetractions =
      AggregateUtil.deriveAggCallNeedRetractions(this, grouping.length, aggCalls)
    val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this)
    val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this)
    new StreamExecGroupTableAggregate(
      grouping,
      aggCalls.toArray,
      aggCallNeedRetractions,
      generateUpdateBefore,
      needRetraction,
      ExecEdge.DEFAULT,
      FlinkTypeFactory.toLogicalRowType(getRowType),
      getRelDetailedDescription)
  }
}
@@ -17,8 +17,6 @@
*/
package org.apache.flink.table.planner.plan.nodes.physical.stream
import org.apache.flink.table.data.RowData
import org.apache.flink.table.planner.plan.nodes.exec.LegacyStreamExecNode
import org.apache.flink.table.planner.plan.utils._
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
@@ -29,7 +27,7 @@ import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
/**
* Base Stream physical RelNode for unbounded group table aggregate.
*/
abstract class StreamExecGroupTableAggregateBase(
abstract class StreamPhysicalGroupTableAggregateBase(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
@@ -37,10 +35,9 @@ abstract class StreamExecGroupTableAggregateBase(
val grouping: Array[Int],
val aggCalls: Seq[AggregateCall])
extends SingleRel(cluster, traitSet, inputRel)
with StreamPhysicalRel
with LegacyStreamExecNode[RowData] {
with StreamPhysicalRel {
val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList(
protected val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList(
this,
grouping.length,
aggCalls)
@@ -178,7 +178,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti
val providedTrait = new ModifyKindSetTrait(builder.build())
createNewNode(agg, children, providedTrait, requiredTrait, requester)
case tagg: StreamExecGroupTableAggregateBase =>
case tagg: StreamPhysicalGroupTableAggregateBase =>
// table agg support all changes in input
val children = visitChildren(tagg, ModifyKindSetTrait.ALL_CHANGES)
// table aggregate will produce all changes, including deletions
@@ -461,7 +461,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti
}
visitSink(sink, sinkRequiredTraits)
case _: StreamPhysicalGroupAggregate | _: StreamExecGroupTableAggregate |
case _: StreamPhysicalGroupAggregate | _: StreamPhysicalGroupTableAggregate |
_: StreamPhysicalLimit | _: StreamPhysicalPythonGroupAggregate |
_: StreamExecPythonGroupTableAggregate =>
// Aggregate, TableAggregate and Limit requires update_before if there are updates
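The two hunks above encode the changelog contract of table aggregates: they accept all change kinds from their input and may themselves emit inserts, updates, and deletes, for example when a newly arrived row evicts an earlier row from a Top-N style result. A hedged Table API sketch of a query that plans into this node; Top2 is a hypothetical registered TableAggregateFunction and the Orders table is assumed to exist:

import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.api.Expressions.call;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;

public class TableAggregateQuerySketch {
    // Hypothetical setup: tEnv has a table "Orders"(user, amount) and a
    // registered TableAggregateFunction "Top2" emitting two fields (f0, f1).
    public static Table top2PerUser(TableEnvironment tEnv) {
        return tEnv.from("Orders")
                .groupBy($("user"))
                // flatAggregate is what produces a group *table* aggregate:
                // each group may emit several rows, and emitted rows may later
                // be retracted, hence ModifyKindSetTrait.ALL_CHANGES above.
                .flatAggregate(call("Top2", $("amount")))
                .select($("user"), $("f0"), $("f1"));
    }
}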
@@ -421,7 +421,7 @@ object FlinkStreamRuleSets {
StreamPhysicalExpandRule.INSTANCE,
// group agg
StreamPhysicalGroupAggregateRule.INSTANCE,
StreamExecGroupTableAggregateRule.INSTANCE,
StreamPhysicalGroupTableAggregateRule.INSTANCE,
StreamPhysicalPythonGroupAggregateRule.INSTANCE,
StreamExecPythonGroupTableAggregateRule.INSTANCE,
// over agg
@@ -18,22 +18,26 @@
package org.apache.flink.table.planner.plan.rules.physical.stream
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableAggregate
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecGroupTableAggregate
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalGroupTableAggregate
import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import scala.collection.JavaConversions._
class StreamExecGroupTableAggregateRule extends ConverterRule(
/**
* Rule to convert a [[FlinkLogicalTableAggregate]] into a [[StreamPhysicalGroupTableAggregate]].
*/
class StreamPhysicalGroupTableAggregateRule extends ConverterRule(
classOf[FlinkLogicalTableAggregate],
FlinkConventions.LOGICAL,
FlinkConventions.STREAM_PHYSICAL,
"StreamExecGroupTableAggregateRule") {
"StreamPhysicalGroupTableAggregateRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val agg: FlinkLogicalTableAggregate = call.rel(0)
@@ -53,7 +57,7 @@ class StreamExecGroupTableAggregateRule extends ConverterRule(
val providedTraitSet = rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
val newInput: RelNode = RelOptRule.convert(agg.getInput, requiredTraitSet)
new StreamExecGroupTableAggregate(
new StreamPhysicalGroupTableAggregate(
rel.getCluster,
providedTraitSet,
newInput,
@@ -64,7 +68,6 @@ class StreamExecGroupTableAggregateRule(
}
}
object StreamExecGroupTableAggregateRule {
val INSTANCE: StreamExecGroupTableAggregateRule = new StreamExecGroupTableAggregateRule()
object StreamPhysicalGroupTableAggregateRule {
val INSTANCE: StreamPhysicalGroupTableAggregateRule = new StreamPhysicalGroupTableAggregateRule()
}
@@ -835,7 +835,7 @@ class FlinkRelMdHandlerTestBase {
builder.add("f1", new BasicSqlType(typeFactory.getTypeSystem, SqlTypeName.INTEGER))
val relDataType = builder.build()
val streamExecTableAgg = new StreamExecGroupTableAggregate(
val streamTableAgg = new StreamPhysicalGroupTableAggregate(
cluster,
logicalTraits,
studentLogicalScan,
@@ -844,7 +844,7 @@
Seq(tableAggCall)
)
(logicalTableAgg, flinkLogicalTableAgg, streamExecTableAgg)
(logicalTableAgg, flinkLogicalTableAgg, streamTableAgg)
}
// equivalent Table API is