Commit 9f8f5cd3 authored by godfreyhe

[FLINK-20738][table-planner-blink] Introduce BatchPhysicalLocalHashAggregate, and make BatchPhysicalLocalHashAggregate extend only FlinkPhysicalRel

This closes #14562
Parent d0d8377b
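
The shape of the change: the local hash aggregate node no longer extends LegacyBatchExecNode and no longer generates its operator in translateToPlanInternal; it extends only FlinkPhysicalRel and hands runtime translation off to a dedicated exec node (a BatchExecHashAggregate created with isMerge = false and isFinal = false) via translateToExecNode. Below is a minimal stand-alone Scala sketch of that split; the trait and class names are simplified stand-ins invented for illustration, not the actual Flink planner types touched in this diff.

// Simplified stand-ins for illustration only; the real interfaces are
// FlinkPhysicalRel and ExecNode in the Flink table planner.
trait ExecNodeSketch {
  def translateToPlan(): Unit
}

trait PhysicalRelSketch {
  // The optimizer-facing node only describes the plan and produces its
  // exec-node counterpart; it no longer owns runtime translation.
  def translateToExecNode(): ExecNodeSketch
}

// Exec node: owns the runtime translation (code generation, operator setup, ...).
final class HashAggExecSketch(grouping: Array[Int], isMerge: Boolean, isFinal: Boolean)
    extends ExecNodeSketch {
  override def translateToPlan(): Unit =
    println(s"translate hash agg: grouping=[${grouping.mkString(",")}], isMerge=$isMerge, isFinal=$isFinal")
}

// Physical node: mirrors how BatchPhysicalLocalHashAggregate now builds a
// BatchExecHashAggregate with isMerge = false and isFinal = false.
final class LocalHashAggRelSketch(grouping: Array[Int]) extends PhysicalRelSketch {
  override def translateToExecNode(): ExecNodeSketch =
    new HashAggExecSketch(grouping, isMerge = false, isFinal = false)
}

object PatternDemo extends App {
  new LocalHashAggRelSketch(Array(0)).translateToExecNode().translateToPlan()
}
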
......@@ -613,7 +613,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
agg.partialAggCalls(aggCallIndex)
case agg: StreamExecGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
agg.aggCalls(aggCallIndex)
case agg: BatchExecLocalHashAggregate =>
case agg: BatchPhysicalLocalHashAggregate =>
getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
case agg: BatchPhysicalHashAggregate if agg.isMerge =>
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
......
......@@ -18,22 +18,12 @@
package org.apache.flink.table.planner.plan.nodes.physical.batch
import org.apache.flink.api.dag.Transformation
import org.apache.flink.configuration.MemorySize
import org.apache.flink.table.api.config.ExecutionConfigOptions
import org.apache.flink.table.data.RowData
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGeneratorContext
import org.apache.flink.table.planner.codegen.agg.batch.{AggWithoutKeysCodeGenerator, HashAggCodeGenerator}
import org.apache.flink.table.planner.delegation.BatchPlanner
import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef}
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, LegacyBatchExecNode}
import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToBatchAggregateInfoList
import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecHashAggregate
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
import org.apache.flink.table.planner.plan.utils.RelExplainUtil
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet}
import org.apache.calcite.rel.RelDistribution.Type
......@@ -51,7 +41,7 @@ import scala.collection.JavaConversions._
*
* @see [[BatchPhysicalGroupAggregateBase]] for more info.
*/
class BatchExecLocalHashAggregate(
class BatchPhysicalLocalHashAggregate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
......@@ -69,11 +59,10 @@ class BatchExecLocalHashAggregate(
auxGrouping,
aggCallToAggFunction,
isMerge = false,
isFinal = false)
with LegacyBatchExecNode[RowData] {
isFinal = false) {
override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new BatchExecLocalHashAggregate(
new BatchPhysicalLocalHashAggregate(
cluster,
traitSet,
inputs.get(0),
......@@ -130,57 +119,25 @@ class BatchExecLocalHashAggregate(
Some(copy(providedTraits, Seq(newInput)))
}
//~ ExecNode methods -----------------------------------------------------------
override protected def translateToPlanInternal(
planner: BatchPlanner): Transformation[RowData] = {
val config = planner.getTableConfig
val input = getInputNodes.get(0).translateToPlan(planner)
.asInstanceOf[Transformation[RowData]]
val ctx = CodeGeneratorContext(config)
val outputType = FlinkTypeFactory.toLogicalRowType(getRowType)
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
val aggInfos = transformToBatchAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputRowType), getAggCallList)
var managedMemory: Long = 0L
val generatedOperator = if (grouping.isEmpty) {
AggWithoutKeysCodeGenerator.genWithoutKeys(
ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping")
} else {
managedMemory = MemorySize.parse(config.getConfiguration.getString(
ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY)).getBytes
new HashAggCodeGenerator(
ctx,
planner.getRelBuilder,
aggInfos,
inputType,
outputType,
grouping,
auxGrouping,
isMerge,
isFinal
).genWithKeys()
}
val operator = new CodeGenOperatorFactory[RowData](generatedOperator)
ExecNodeUtil.createOneInputTransformation(
input,
getRelDetailedDescription,
operator,
InternalTypeInfo.of(outputType),
input.getParallelism,
managedMemory)
override def translateToExecNode(): ExecNode[_] = {
new BatchExecHashAggregate(
grouping,
auxGrouping,
getAggCallList.toArray,
FlinkTypeFactory.toLogicalRowType(inputRowType),
false, // isMerge is always false
false, // isFinal is always false
getInputEdge,
FlinkTypeFactory.toLogicalRowType(getRowType),
getRelDetailedDescription
)
}
override def getInputEdges: util.List[ExecEdge] = {
private def getInputEdge: ExecEdge = {
if (grouping.length == 0) {
List(
ExecEdge.builder()
.damBehavior(ExecEdge.DamBehavior.END_INPUT)
.build())
ExecEdge.builder().damBehavior(ExecEdge.DamBehavior.END_INPUT).build()
} else {
List(ExecEdge.DEFAULT)
ExecEdge.DEFAULT
}
}
}
......@@ -22,7 +22,7 @@ import org.apache.flink.annotation.Experimental
import org.apache.flink.configuration.ConfigOption
import org.apache.flink.configuration.ConfigOptions.key
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecLocalHashAggregate
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalHashAggregate
import org.apache.flink.table.planner.plan.utils.{FlinkRelMdUtil, FlinkRelOptUtil}
import org.apache.calcite.plan.RelOptRule
......@@ -40,7 +40,7 @@ trait BatchExecJoinRuleBase {
val newInput = RelOptRule.convert(node, localRequiredTraitSet)
val providedTraitSet = localRequiredTraitSet
new BatchExecLocalHashAggregate(
new BatchPhysicalLocalHashAggregate(
node.getCluster,
providedTraitSet,
newInput,
......
......@@ -24,7 +24,7 @@ import org.apache.flink.table.planner.JArrayList
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalHashAggregate, BatchExecLocalSortAggregate, BatchPhysicalGroupAggregateBase}
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalSortAggregate, BatchPhysicalGroupAggregateBase, BatchPhysicalLocalHashAggregate}
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil}
import org.apache.flink.table.planner.utils.AggregatePhaseStrategy
import org.apache.flink.table.planner.utils.TableConfigUtils.getAggPhaseStrategy
......@@ -209,7 +209,7 @@ trait BatchPhysicalAggRuleBase {
aggBufferTypes.map(_.map(fromDataTypeToLogicalType)))
if (isLocalHashAgg) {
new BatchExecLocalHashAggregate(
new BatchPhysicalLocalHashAggregate(
cluster,
traitSet,
input,
......
......@@ -39,7 +39,7 @@ import scala.collection.JavaConversions._
* {{{
* BatchPhysicalHashAggregate (global)
* +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
* +- BatchExecLocalHashAggregate (local)
* +- BatchPhysicalLocalHashAggregate (local)
* +- input of agg
* }}}
* when all aggregate functions are mergeable
......@@ -94,7 +94,7 @@ class BatchPhysicalHashAggRule
// create two-phase agg if possible
if (isTwoPhaseAggWorkable(aggFunctions, tableConfig)) {
// create BatchExecLocalHashAggregate
// create BatchPhysicalLocalHashAggregate
val localRequiredTraitSet = input.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)
val newInput = RelOptRule.convert(input, localRequiredTraitSet)
val providedTraitSet = localRequiredTraitSet
......
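
The local/global plan shape described in the rule's doc comment above can be inspected from a user program. The snippet below is only an illustration: it assumes the Flink 1.12-era table API (useBlinkPlanner, the built-in datagen connector), and whether the optimizer actually chooses the two-phase split depends on the aggregate functions being mergeable and on the table.optimizer.agg-phase-strategy setting.

import org.apache.flink.table.api.{EnvironmentSettings, TableEnvironment}

object TwoPhaseHashAggExplain {
  def main(args: Array[String]): Unit = {
    // Batch mode with the blink planner (Flink 1.12-era API assumed).
    val settings = EnvironmentSettings.newInstance().useBlinkPlanner().inBatchMode().build()
    val tEnv = TableEnvironment.create(settings)

    // Bounded throwaway source, used only so the plan can be explained.
    tEnv.executeSql(
      "CREATE TABLE T (a INT, b BIGINT) WITH (" +
        "'connector' = 'datagen', 'number-of-rows' = '1000')")

    // For mergeable aggregate functions the optimizer may choose the split
    // shown in the rule's doc comment:
    //   BatchPhysicalHashAggregate (global)
    //   +- BatchPhysicalExchange (hash by group keys)
    //      +- BatchPhysicalLocalHashAggregate (local)
    //         +- input of agg
    println(tEnv.explainSql("SELECT a, SUM(b) FROM T GROUP BY a"))
  }
}
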
......@@ -19,7 +19,7 @@
package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalHashAggregate, BatchExecLocalHashAggregate}
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalHashAggregate, BatchPhysicalLocalHashAggregate}
import org.apache.calcite.plan.RelOptRule._
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
......@@ -31,13 +31,13 @@ import org.apache.calcite.rel.RelNode
*/
class RemoveRedundantLocalHashAggRule extends RelOptRule(
operand(classOf[BatchPhysicalHashAggregate],
operand(classOf[BatchExecLocalHashAggregate],
operand(classOf[BatchPhysicalLocalHashAggregate],
operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any))),
"RemoveRedundantLocalHashAggRule") {
override def onMatch(call: RelOptRuleCall): Unit = {
val globalAgg: BatchPhysicalHashAggregate = call.rel(0)
val localAgg: BatchExecLocalHashAggregate = call.rel(1)
val localAgg: BatchPhysicalLocalHashAggregate = call.rel(1)
val inputOfLocalAgg = localAgg.getInput
val newGlobalAgg = new BatchPhysicalHashAggregate(
globalAgg.getCluster,
......
......@@ -979,7 +979,7 @@ class FlinkRelMdHandlerTestBase {
val hash0 = FlinkRelDistribution.hash(Array(0), requireStrict = true)
val hash3 = FlinkRelDistribution.hash(Array(3), requireStrict = true)
val batchLocalAgg = new BatchExecLocalHashAggregate(
val batchLocalAgg = new BatchPhysicalLocalHashAggregate(
cluster,
batchPhysicalTraits,
studentBatchScan,
......@@ -1106,7 +1106,7 @@ class FlinkRelMdHandlerTestBase {
.add("sum_score", doubleType)
.add("cnt", longType).build()
val batchLocalAggWithAuxGroup = new BatchExecLocalHashAggregate(
val batchLocalAggWithAuxGroup = new BatchPhysicalLocalHashAggregate(
cluster,
batchPhysicalTraits,
studentBatchScan,
......