Commit 146c68df authored by godfreyhe

[FLINK-20738][table-planner-blink] Rename BatchExecGroupAggregateBase to BatchPhysicalGroupAggregateBase and do some refactoring

This closes #14562

Parent 87260887
@@ -19,7 +19,7 @@
package org.apache.flink.table.planner.plan.metadata
import org.apache.flink.table.planner.JDouble
-import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase}
+import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase, BatchPhysicalGroupAggregateBase}
import org.apache.flink.table.planner.plan.stats._
import org.apache.flink.table.planner.plan.utils.AggregateUtil
@@ -34,16 +34,16 @@ import org.apache.calcite.sql.{SqlKind, SqlOperator}
import scala.collection.JavaConversions._
/**
 * Estimates selectivity of rows meeting an agg-call predicate on an Aggregate.
 *
 * A filter predicate on an Aggregate may contain two parts:
 * one is on group by columns, another is on aggregate call's result.
 * The first part is handled by [[SelectivityEstimator]],
 * the second part is handled by this Estimator.
 *
 * @param agg aggregate node
 * @param mq Metadata query
 */
class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
  extends RexVisitorImpl[Option[Double]](true) {
@@ -53,15 +53,15 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
  private[flink] val defaultAggCallSelectivity = Some(0.01d)
  /**
   * Gets AggregateCall from aggregate node
   */
  def getSupportedAggCall(outputIdx: Int): Option[AggregateCall] = {
    val (fullGrouping, aggCalls) = agg match {
      case rel: Aggregate =>
        val (auxGroupSet, otherAggCalls) = AggregateUtil.checkAndSplitAggCalls(rel)
        (rel.getGroupSet.toArray ++ auxGroupSet, otherAggCalls)
-      case rel: BatchExecGroupAggregateBase =>
-        (rel.getGrouping ++ rel.getAuxGrouping, rel.getAggCallList)
+      case rel: BatchPhysicalGroupAggregateBase =>
+        (rel.grouping ++ rel.auxGrouping, rel.getAggCallList)
      case rel: BatchExecLocalHashWindowAggregate =>
        val fullGrouping = rel.getGrouping ++ Array(rel.inputTimeFieldIndex) ++ rel.getAuxGrouping
        (fullGrouping, rel.getAggCallList)
@@ -79,9 +79,9 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
  }
  /**
   * Returns whether the given aggCall is supported now
   * TODO supports more
   */
  def isSupportedAggCall(aggCall: AggregateCall): Boolean = {
    aggCall.getAggregation.getKind match {
      case SqlKind.SUM | SqlKind.MAX | SqlKind.MIN | SqlKind.AVG => true
@@ -91,8 +91,8 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
  }
  /**
   * Gets aggCall's interval through its argument's interval.
   */
  def getAggCallInterval(aggCall: AggregateCall): ValueInterval = {
    val aggInput = agg.getInput(0)
@@ -159,12 +159,12 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
  }
  /**
   * Returns a percentage of rows meeting a filter predicate on aggregate.
   *
   * @param predicate predicate whose selectivity is to be estimated against aggregate calls.
   * @return estimated selectivity (between 0.0 and 1.0),
   *         or None if no reliable estimate can be determined.
   */
  def evaluate(predicate: RexNode): Option[Double] = {
    try {
      if (predicate == null) {
@@ -213,12 +213,12 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
  }
  /**
   * Returns a percentage of rows meeting a single condition in Filter node.
   *
   * @param singlePredicate predicate whose selectivity is to be estimated against aggregate calls.
   * @return an optional double value to show the percentage of rows meeting a given condition.
   *         It returns None if the condition is not supported.
   */
  private def estimateSinglePredicate(singlePredicate: RexCall): Option[Double] = {
    val operands = singlePredicate.getOperands
    singlePredicate.getOperator match {
@@ -250,14 +250,14 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
  }
  /**
   * Returns a percentage of rows meeting a binary comparison expression containing two columns.
   *
   * @param op a binary comparison operator, including =, <=>, <, <=, >, >=
   * @param left the left RexInputRef
   * @param right the right RexInputRef
   * @return an optional double value to show the percentage of rows meeting a given condition.
   *         It returns None if no statistics collected for a given column.
   */
  private def estimateComparison(op: SqlOperator, left: RexNode, right: RexNode): Option[Double] = {
    // if we can't handle some cases, uses SelectivityEstimator's default value
    // (consistent with normal case).
@@ -302,14 +302,14 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
  }
  /**
   * Returns a percentage of rows meeting an equality (=) expression.
   * e.g. count(a) = 10
   *
   * @param inputRef a RexInputRef
   * @param literal a literal value (or constant)
   * @return an optional double value to show the percentage of rows meeting a given condition.
   *         It returns None if no statistics collected for a given column.
   */
  private def estimateEquals(inputRef: RexInputRef, literal: RexLiteral): Option[Double] = {
    if (literal.isNull) {
      return se.defaultIsNullSelectivity
@@ -345,15 +345,15 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
  }
  /**
   * Returns a percentage of rows meeting a binary comparison expression.
   * e.g. sum(a) > 10
   *
   * @param op a binary comparison operator, including <, <=, >, >=
   * @param inputRef a RexInputRef
   * @param literal a literal value (or constant)
   * @return an optional double value to show the percentage of rows meeting a given condition.
   *         It returns None if no statistics collected for a given column.
   */
  private def estimateComparison(
      op: SqlOperator,
      inputRef: RexInputRef,
@@ -372,15 +372,15 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
  }
  /**
   * Returns a percentage of rows meeting a binary numeric comparison expression.
   * This method evaluate expression for Numeric/Boolean/Date/Time/Timestamp columns.
   *
   * @param op a binary comparison operator, including <, <=, >, >=
   * @param aggCall an AggregateCall
   * @param literal a literal value (or constant)
   * @return an optional double value to show the percentage of rows meeting a given condition.
   *         It returns None if no statistics collected for a given column.
   */
  private def estimateNumericComparison(
      op: SqlOperator,
      aggCall: AggregateCall,
......
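(Aside: the estimator above splits a post-aggregation filter into a part on group-by columns and a part on aggregate-call results. The following self-contained Scala sketch illustrates that split on a simplified predicate shape; the case class and helper names are illustrative only and are not part of the Flink planner.)

```scala
// Illustrative sketch only: splitting a conjunctive filter over an aggregate's
// output into predicates on group-by columns (handled by a SelectivityEstimator-
// like component) and predicates on aggregate-call results (handled by an
// AggCallSelectivityEstimator-like component).
object PredicateSplitSketch {

  // A simplified predicate: `outputIdx OP literal`, where outputIdx is an
  // output field index of the aggregate node.
  final case class SimplePredicate(outputIdx: Int, op: String, literal: Double)

  /** Splits predicates by whether they reference a group-key output or an agg-call output. */
  def split(
      predicates: Seq[SimplePredicate],
      groupKeyCount: Int): (Seq[SimplePredicate], Seq[SimplePredicate]) =
    predicates.partition(_.outputIdx < groupKeyCount)

  def main(args: Array[String]): Unit = {
    // e.g. SELECT a, SUM(b) AS s FROM t GROUP BY a HAVING a > 1 AND s > 10
    // output fields: 0 -> a (group key), 1 -> s (agg call)
    val preds = Seq(SimplePredicate(0, ">", 1.0), SimplePredicate(1, ">", 10.0))
    val (onGroupKeys, onAggCalls) = split(preds, groupKeyCount = 1)
    println(s"group-key predicates: $onGroupKeys") // a > 1
    println(s"agg-call predicates:  $onAggCalls")  // s > 10
  }
}
```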
@@ -329,12 +329,12 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
  }
  def areColumnsUnique(
-      rel: BatchExecGroupAggregateBase,
+      rel: BatchPhysicalGroupAggregateBase,
      mq: RelMetadataQuery,
      columns: ImmutableBitSet,
      ignoreNulls: Boolean): JBoolean = {
    if (rel.isFinal) {
-      areColumnsUniqueOnAggregate(rel.getGrouping, mq, columns, ignoreNulls)
+      areColumnsUniqueOnAggregate(rel.grouping, mq, columns, ignoreNulls)
    } else {
      null
    }
......
@@ -309,7 +309,7 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
  }
  def getDistinctRowCount(
-      rel: BatchExecGroupAggregateBase,
+      rel: BatchPhysicalGroupAggregateBase,
      mq: RelMetadataQuery,
      groupKey: ImmutableBitSet,
      predicate: RexNode): JDouble = {
@@ -397,7 +397,7 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
      predicate: RexNode): (Option[RexNode], Option[RexNode]) = agg match {
    case rel: Aggregate =>
      FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
-    case rel: BatchExecGroupAggregateBase =>
+    case rel: BatchPhysicalGroupAggregateBase =>
      FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
    case rel: BatchExecWindowAggregateBase =>
      FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
......
@@ -19,7 +19,7 @@ package org.apache.flink.table.planner.plan.metadata
import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.FilteredColumnInterval
import org.apache.flink.table.planner.plan.nodes.calcite.TableAggregate
-import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecGroupAggregateBase
+import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase
import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamPhysicalGlobalGroupAggregate, StreamPhysicalGroupAggregate, StreamPhysicalGroupTableAggregate, StreamPhysicalLocalGroupAggregate}
import org.apache.flink.table.planner.plan.stats.ValueInterval
import org.apache.flink.table.planner.plan.utils.ColumnIntervalUtil
@@ -176,7 +176,7 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredC
  }
  def getFilteredColumnInterval(
-      aggregate: BatchExecGroupAggregateBase,
+      aggregate: BatchPhysicalGroupAggregateBase,
      mq: RelMetadataQuery,
      columnIndex: Int,
      filterArg: Int): ValueInterval = {
......
@@ -25,7 +25,7 @@ import org.apache.flink.table.planner.plan.`trait`.RelModifiedMonotonicity
import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.ModifiedMonotonicity
import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank, TableAggregate, WindowAggregate, WindowTableAggregate}
import org.apache.flink.table.planner.plan.nodes.logical._
-import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchPhysicalCorrelate}
+import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalCorrelate, BatchPhysicalGroupAggregateBase}
import org.apache.flink.table.planner.plan.nodes.physical.stream._
import org.apache.flink.table.planner.plan.schema.{FlinkPreparingTableBase, TableSourceTable}
import org.apache.flink.table.planner.plan.stats.{WithLower, WithUpper}
@@ -51,9 +51,9 @@ import java.util.Collections
import scala.collection.JavaConversions._
/**
 * FlinkRelMdModifiedMonotonicity supplies a default implementation of
 * [[FlinkRelMetadataQuery#getRelModifiedMonotonicity]] for logical algebra.
 */
class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMonotonicity] {
  override def getDef: MetadataDef[ModifiedMonotonicity] = FlinkMetadata.ModifiedMonotonicity.DEF
@@ -239,8 +239,8 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
  }
  def getRelModifiedMonotonicity(
      rel: StreamPhysicalMiniBatchAssigner,
      mq: RelMetadataQuery): RelModifiedMonotonicity = {
    getMonotonicity(rel.getInput, mq, rel.getRowType.getFieldCount)
  }
@@ -256,8 +256,8 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
  }
  def getRelModifiedMonotonicity(
      rel: WindowTableAggregate,
      mq: RelMetadataQuery): RelModifiedMonotonicity = {
    if (allAppend(mq, rel.getInput)) {
      constants(rel.getRowType.getFieldCount)
    } else {
@@ -272,7 +272,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
  }
  def getRelModifiedMonotonicity(
-      rel: BatchExecGroupAggregateBase,
+      rel: BatchPhysicalGroupAggregateBase,
      mq: RelMetadataQuery): RelModifiedMonotonicity = null
  def getRelModifiedMonotonicity(
@@ -324,8 +324,8 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
  }
  def getRelModifiedMonotonicity(
      rel: StreamExecGroupWindowTableAggregate,
      mq: RelMetadataQuery): RelModifiedMonotonicity = {
    if (allAppend(mq, rel.getInput)) {
      constants(rel.getRowType.getFieldCount)
    } else {
@@ -546,9 +546,9 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
  def getRelModifiedMonotonicity(rel: RelNode, mq: RelMetadataQuery): RelModifiedMonotonicity = null
  /**
   * Utility to create a RelModifiedMonotonicity which all fields is modified constant which
   * means all the field's value will not be modified.
   */
  def constants(fieldCount: Int): RelModifiedMonotonicity = {
    new RelModifiedMonotonicity(Array.fill(fieldCount)(CONSTANT))
  }
@@ -558,8 +558,8 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
  }
  /**
   * These operator won't generate update itself
   */
  def getMonotonicity(
      input: RelNode,
      mq: RelMetadataQuery,
......
@@ -20,7 +20,7 @@ package org.apache.flink.table.planner.plan.metadata
import org.apache.flink.table.planner.JDouble
import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank}
-import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecGroupAggregateBase
+import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.RelNode
@@ -51,7 +51,9 @@ class FlinkRelMdPercentageOriginalRows private
  def getPercentageOriginalRows(rel: Aggregate, mq: RelMetadataQuery): JDouble =
    mq.getPercentageOriginalRows(rel.getInput)
-  def getPercentageOriginalRows(rel: BatchExecGroupAggregateBase, mq: RelMetadataQuery): JDouble = {
+  def getPercentageOriginalRows(
+      rel: BatchPhysicalGroupAggregateBase,
+      mq: RelMetadataQuery): JDouble = {
    mq.getPercentageOriginalRows(rel.getInput)
  }
......
@@ -218,7 +218,7 @@ class FlinkRelMdPopulationSize private extends MetadataHandler[BuiltInMetadata.P
  }
  def getPopulationSize(
-      rel: BatchExecGroupAggregateBase,
+      rel: BatchPhysicalGroupAggregateBase,
      mq: RelMetadataQuery,
      groupKey: ImmutableBitSet): JDouble = {
    // for global agg which has inner local agg, it passes the parameters to input directly
......
@@ -138,15 +138,15 @@ class FlinkRelMdRowCount private extends MetadataHandler[BuiltInMetadata.RowCoun
    }
  }
-  def getRowCount(rel: BatchExecGroupAggregateBase, mq: RelMetadataQuery): JDouble = {
+  def getRowCount(rel: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery): JDouble = {
    getRowCountOfBatchExecAgg(rel, mq)
  }
  private def getRowCountOfBatchExecAgg(rel: SingleRel, mq: RelMetadataQuery): JDouble = {
    val input = rel.getInput
    val (grouping, isFinal, isMerge) = rel match {
-      case agg: BatchExecGroupAggregateBase =>
-        (ImmutableBitSet.of(agg.getGrouping: _*), agg.isFinal, agg.isMerge)
+      case agg: BatchPhysicalGroupAggregateBase =>
+        (ImmutableBitSet.of(agg.grouping: _*), agg.isFinal, agg.isMerge)
      case windowAgg: BatchExecWindowAggregateBase =>
        (ImmutableBitSet.of(windowAgg.getGrouping: _*), windowAgg.isFinal, windowAgg.isMerge)
      case _ => throw new IllegalArgumentException(s"Unknown aggregate type ${rel.getRelTypeName}!")
......
@@ -97,7 +97,7 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
      predicate: RexNode): JDouble = getSelectivityOfAgg(rel, mq, predicate)
  def getSelectivity(
-      rel: BatchExecGroupAggregateBase,
+      rel: BatchPhysicalGroupAggregateBase,
      mq: RelMetadataQuery,
      predicate: RexNode): JDouble = getSelectivityOfAgg(rel, mq, predicate)
@@ -130,7 +130,7 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
    } else {
      val hasLocalAgg = agg match {
        case _: Aggregate => false
-        case rel: BatchExecGroupAggregateBase => rel.isFinal && rel.isMerge
+        case rel: BatchPhysicalGroupAggregateBase => rel.isFinal && rel.isMerge
        case rel: BatchExecWindowAggregateBase => rel.isFinal && rel.isMerge
        case _ => throw new IllegalArgumentException(s"Cannot handle ${agg.getRelTypeName}!")
      }
@@ -147,7 +147,7 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
    val (childPred, restPred) = agg match {
      case rel: Aggregate =>
        FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
-      case rel: BatchExecGroupAggregateBase =>
+      case rel: BatchPhysicalGroupAggregateBase =>
        FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
      case rel: BatchExecWindowAggregateBase =>
        FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
......
@@ -187,11 +187,13 @@ class FlinkRelMdSize private extends MetadataHandler[BuiltInMetadata.Size] {
    sizesBuilder.build
  }
-  def averageColumnSizes(rel: BatchExecGroupAggregateBase, mq: RelMetadataQuery): JList[JDouble] = {
+  def averageColumnSizes(
+      rel: BatchPhysicalGroupAggregateBase,
+      mq: RelMetadataQuery): JList[JDouble] = {
    // note: the logical to estimate column sizes of AggregateBatchExecBase is different from
    // Calcite Aggregate because AggregateBatchExecBase's rowTypes is not composed by
    // grouping columns + aggFunctionCall results
-    val mapInputToOutput = (rel.getGrouping ++ rel.getAuxGrouping).zipWithIndex.toMap
+    val mapInputToOutput = (rel.grouping ++ rel.auxGrouping).zipWithIndex.toMap
    getColumnSizesFromInputOrType(rel, mq, mapInputToOutput)
  }
......
@@ -211,10 +211,10 @@ class FlinkRelMdUniqueGroups private extends MetadataHandler[UniqueGroups] {
  }
  def getUniqueGroups(
-      agg: BatchExecGroupAggregateBase,
+      agg: BatchPhysicalGroupAggregateBase,
      mq: RelMetadataQuery,
      columns: ImmutableBitSet): ImmutableBitSet = {
-    val grouping = agg.getGrouping
+    val grouping = agg.grouping
    getUniqueGroupsOfAggregate(agg.getRowType.getFieldCount, grouping, agg.getInput, mq, columns)
  }
......
@@ -84,7 +84,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
      columns.indexOf(c)
    }
    val builder = ImmutableSet.builder[ImmutableBitSet]()
-    builder.add(ImmutableBitSet.of(columnIndices:_*))
+    builder.add(ImmutableBitSet.of(columnIndices: _*))
    val uniqueSet = sourceTable.uniqueKeysSet().orElse(null)
    if (uniqueSet != null) {
      builder.addAll(uniqueSet)
@@ -198,8 +198,8 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
  }
  /**
   * Whether the [[RexCall]] is a cast that doesn't lose any information.
   */
  private def isFidelityCast(call: RexCall): Boolean = {
    if (call.getKind != SqlKind.CAST) {
      return false
@@ -334,11 +334,11 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
  }
  def getUniqueKeys(
-      rel: BatchExecGroupAggregateBase,
+      rel: BatchPhysicalGroupAggregateBase,
      mq: RelMetadataQuery,
      ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
    if (rel.isFinal) {
-      getUniqueKeysOnAggregate(rel.getGrouping, mq, ignoreNulls)
+      getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls)
    } else {
      null
    }
......
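(Aside: `isFidelityCast` above only has to decide whether a cast can lose information. A minimal standalone sketch of that idea follows, using an assumed widening order for a few numeric SQL type names; it is not Calcite's or the Flink planner's actual rule.)

```scala
// Sketch only: a cast is treated as "fidelity preserving" when it widens a
// numeric type (e.g. INT -> BIGINT). The ranking below is an assumption for
// illustration, not the rule used by Calcite or the Flink planner.
object FidelityCastSketch {
  private val numericRank = Map("TINYINT" -> 1, "SMALLINT" -> 2, "INTEGER" -> 3, "BIGINT" -> 4)

  /** Returns true if casting `from` to `to` cannot lose information under this ranking. */
  def isLossless(from: String, to: String): Boolean =
    (numericRank.get(from), numericRank.get(to)) match {
      case (Some(f), Some(t)) => f <= t // widening numeric cast keeps every value
      case _ => from == to              // otherwise only the identity cast is safe here
    }

  def main(args: Array[String]): Unit = {
    println(isLossless("INTEGER", "BIGINT")) // true
    println(isLossless("BIGINT", "INTEGER")) // false: the value may overflow
  }
}
```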
@@ -29,7 +29,6 @@ import org.apache.calcite.rel.RelDistribution.Type.{HASH_DISTRIBUTED, SINGLETON}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.{RelNode, RelWriter}
-import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.util.{ImmutableIntList, Util}
import java.util
@@ -39,11 +38,10 @@ import scala.collection.JavaConversions._
/**
 * Batch physical RelNode for (global) hash-based aggregate operator.
 *
- * @see [[BatchExecGroupAggregateBase]] for more info.
+ * @see [[BatchPhysicalGroupAggregateBase]] for more info.
 */
class BatchExecHashAggregate(
    cluster: RelOptCluster,
-    relBuilder: RelBuilder,
    traitSet: RelTraitSet,
    inputRel: RelNode,
    outputRowType: RelDataType,
@@ -55,7 +53,6 @@ class BatchExecHashAggregate(
    isMerge: Boolean)
  extends BatchExecHashAggregateBase(
    cluster,
-    relBuilder,
    traitSet,
    inputRel,
    outputRowType,
@@ -70,7 +67,6 @@ class BatchExecHashAggregate(
  override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
    new BatchExecHashAggregate(
      cluster,
-      relBuilder,
      traitSet,
      inputs.get(0),
      outputRowType,
......
@@ -40,17 +40,15 @@ import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.metadata.RelMetadataQuery
-import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.util.Util
/**
 * Batch physical RelNode for hash-based aggregate operator.
 *
- * @see [[BatchExecGroupAggregateBase]] for more info.
+ * @see [[BatchPhysicalGroupAggregateBase]] for more info.
 */
abstract class BatchExecHashAggregateBase(
    cluster: RelOptCluster,
-    relBuilder: RelBuilder,
    traitSet: RelTraitSet,
    inputRel: RelNode,
    outputRowType: RelDataType,
@@ -61,13 +59,11 @@ abstract class BatchExecHashAggregateBase(
    aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
    isMerge: Boolean,
    isFinal: Boolean)
-  extends BatchExecGroupAggregateBase(
+  extends BatchPhysicalGroupAggregateBase(
    cluster,
-    relBuilder,
    traitSet,
    inputRel,
    outputRowType,
-    inputRowType,
    grouping,
    auxGrouping,
    aggCallToAggFunction,
@@ -114,17 +110,25 @@ abstract class BatchExecHashAggregateBase(
    val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
    val aggInfos = transformToBatchAggregateInfoList(
-      FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1))
+      FlinkTypeFactory.toLogicalRowType(aggInputRowType), getAggCallList)
    var managedMemory: Long = 0L
    val generatedOperator = if (grouping.isEmpty) {
      AggWithoutKeysCodeGenerator.genWithoutKeys(
-        ctx, relBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping")
+        ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping")
    } else {
      managedMemory = MemorySize.parse(config.getConfiguration.getString(
        ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY)).getBytes
      new HashAggCodeGenerator(
-        ctx, relBuilder, aggInfos, inputType, outputType, grouping, auxGrouping, isMerge, isFinal
+        ctx,
+        planner.getRelBuilder,
+        aggInfos,
+        inputType,
+        outputType,
+        grouping,
+        auxGrouping,
+        isMerge,
+        isFinal
      ).genWithKeys()
    }
    val operator = new CodeGenOperatorFactory[RowData](generatedOperator)
......
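(Aside: the hunk above sizes the operator's managed memory from `ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY`. A hedged example of tuning that option from user code follows; the option key string and the chosen value are assumptions to verify against the Flink release in use.)

```scala
// Sketch: raising the managed memory reserved for the batch hash aggregate operator.
// Assumes the option key "table.exec.resource.hash-agg.memory"; check
// ExecutionConfigOptions in the Flink version in use for the exact key and default.
import org.apache.flink.table.api.{EnvironmentSettings, TableEnvironment}

object HashAggMemoryExample {
  def main(args: Array[String]): Unit = {
    val settings = EnvironmentSettings.newInstance().inBatchMode().build()
    val tEnv = TableEnvironment.create(settings)
    tEnv.getConfig.getConfiguration.setString("table.exec.resource.hash-agg.memory", "256 mb")
  }
}
```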
@@ -28,7 +28,6 @@ import org.apache.calcite.rel.RelDistribution.Type
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.{RelNode, RelWriter}
-import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.util.ImmutableIntList
import java.util
@@ -36,13 +35,12 @@ import java.util
import scala.collection.JavaConversions._
/**
 * Batch physical RelNode for local hash-based aggregate operator.
 *
- * @see [[BatchExecGroupAggregateBase]] for more info.
+ * @see [[BatchPhysicalGroupAggregateBase]] for more info.
 */
class BatchExecLocalHashAggregate(
    cluster: RelOptCluster,
-    relBuilder: RelBuilder,
    traitSet: RelTraitSet,
    inputRel: RelNode,
    outputRowType: RelDataType,
@@ -52,7 +50,6 @@ class BatchExecLocalHashAggregate(
    aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)])
  extends BatchExecHashAggregateBase(
    cluster,
-    relBuilder,
    traitSet,
    inputRel,
    outputRowType,
@@ -67,7 +64,6 @@ class BatchExecLocalHashAggregate(
  override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
    new BatchExecLocalHashAggregate(
      cluster,
-      relBuilder,
      traitSet,
      inputs.get(0),
      outputRowType,
......
@@ -28,7 +28,6 @@ import org.apache.calcite.rel.RelDistribution.Type
import org.apache.calcite.rel._
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.util.ImmutableIntList
import java.util
@@ -36,13 +35,12 @@ import java.util
import scala.collection.JavaConversions._
/**
 * Batch physical RelNode for local sort-based aggregate operator.
 *
- * @see [[BatchExecGroupAggregateBase]] for more info.
+ * @see [[BatchPhysicalGroupAggregateBase]] for more info.
 */
class BatchExecLocalSortAggregate(
    cluster: RelOptCluster,
-    relBuilder: RelBuilder,
    traitSet: RelTraitSet,
    inputRel: RelNode,
    outputRowType: RelDataType,
@@ -52,7 +50,6 @@ class BatchExecLocalSortAggregate(
    aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)])
  extends BatchExecSortAggregateBase(
    cluster,
-    relBuilder,
    traitSet,
    inputRel,
    outputRowType,
@@ -67,7 +64,6 @@ class BatchExecLocalSortAggregate(
  override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
    new BatchExecLocalSortAggregate(
      cluster,
-      relBuilder,
      traitSet,
      inputs.get(0),
      outputRowType,
......
@@ -49,8 +49,8 @@ import java.util
import scala.collection.JavaConversions._
/**
 * Batch physical RelNode for aggregate (Python user defined aggregate function).
 */
class BatchExecPythonGroupAggregate(
    cluster: RelOptCluster,
    traitSet: RelTraitSet,
@@ -62,13 +62,11 @@ class BatchExecPythonGroupAggregate(
    auxGrouping: Array[Int],
    aggCalls: Seq[AggregateCall],
    aggFunctions: Array[UserDefinedFunction])
-  extends BatchExecGroupAggregateBase(
+  extends BatchPhysicalGroupAggregateBase(
    cluster,
-    null,
    traitSet,
    inputRel,
    outputRowType,
-    inputRowType,
    grouping,
    auxGrouping,
    aggCalls.zip(aggFunctions),
......
@@ -29,7 +29,6 @@ import org.apache.calcite.rel.RelDistribution.Type.{HASH_DISTRIBUTED, SINGLETON}
import org.apache.calcite.rel._
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.util.{ImmutableIntList, Util}
import java.util
@@ -37,13 +36,12 @@ import java.util
import scala.collection.JavaConversions._
/**
 * Batch physical RelNode for (global) sort-based aggregate operator.
 *
- * @see [[BatchExecGroupAggregateBase]] for more info.
+ * @see [[BatchPhysicalGroupAggregateBase]] for more info.
 */
class BatchExecSortAggregate(
    cluster: RelOptCluster,
-    relBuilder: RelBuilder,
    traitSet: RelTraitSet,
    inputRel: RelNode,
    outputRowType: RelDataType,
@@ -55,7 +53,6 @@ class BatchExecSortAggregate(
    isMerge: Boolean)
  extends BatchExecSortAggregateBase(
    cluster,
-    relBuilder,
    traitSet,
    inputRel,
    outputRowType,
@@ -70,7 +67,6 @@ class BatchExecSortAggregate(
  override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
    new BatchExecSortAggregate(
      cluster,
-      relBuilder,
      traitSet,
      inputs.get(0),
      outputRowType,
......
@@ -36,16 +36,14 @@ import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.metadata.RelMetadataQuery
-import org.apache.calcite.tools.RelBuilder
/**
 * Batch physical RelNode for sort-based aggregate operator.
 *
- * @see [[BatchExecGroupAggregateBase]] for more info.
+ * @see [[BatchPhysicalGroupAggregateBase]] for more info.
 */
abstract class BatchExecSortAggregateBase(
    cluster: RelOptCluster,
-    relBuilder: RelBuilder,
    traitSet: RelTraitSet,
    inputRel: RelNode,
    outputRowType: RelDataType,
@@ -56,13 +54,11 @@ abstract class BatchExecSortAggregateBase(
    aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
    isMerge: Boolean,
    isFinal: Boolean)
-  extends BatchExecGroupAggregateBase(
+  extends BatchPhysicalGroupAggregateBase(
    cluster,
-    relBuilder,
    traitSet,
    inputRel,
    outputRowType,
-    inputRowType,
    grouping,
    auxGrouping,
    aggCallToAggFunction,
@@ -95,14 +91,22 @@ abstract class BatchExecSortAggregateBase(
    val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
    val aggInfos = transformToBatchAggregateInfoList(
-      FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1))
+      FlinkTypeFactory.toLogicalRowType(aggInputRowType), getAggCallList)
    val generatedOperator = if (grouping.isEmpty) {
      AggWithoutKeysCodeGenerator.genWithoutKeys(
-        ctx, relBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping")
+        ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping")
    } else {
      SortAggCodeGenerator.genWithKeys(
-        ctx, relBuilder, aggInfos, inputType, outputType, grouping, auxGrouping, isMerge, isFinal)
+        ctx,
+        planner.getRelBuilder,
+        aggInfos,
+        inputType,
+        outputType,
+        grouping,
+        auxGrouping,
+        isMerge,
+        isFinal)
    }
    val operator = new CodeGenOperatorFactory[RowData](generatedOperator)
    ExecNodeUtil.createOneInputTransformation(
......
@@ -28,32 +28,29 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rel.{RelNode, SingleRel}
-import org.apache.calcite.tools.RelBuilder
/**
 * Batch physical RelNode for aggregate.
 *
 * <P>There are two differences between this node and [[Aggregate]]:
 * 1. This node supports two-stage aggregation to reduce data-shuffling:
 *    local-aggregation and global-aggregation.
 *    local-aggregation produces a partial result for each group before shuffle in stage 1,
 *    and then the partially aggregated results are shuffled to global-aggregation
 *    which produces the final result in stage 2.
 *    Two-stage aggregation is enabled only if all aggregate functions are mergeable.
 *    (e.g. SUM, AVG, MAX)
 * 2. This node supports auxiliary group keys which will not be computed as key and
 *    does not also affect the correctness of the final result. [[Aggregate]] does not distinguish
 *    group keys and auxiliary group keys, and combines them as a complete `groupSet`.
 */
-abstract class BatchExecGroupAggregateBase(
+abstract class BatchPhysicalGroupAggregateBase(
    cluster: RelOptCluster,
-    relBuilder: RelBuilder,
    traitSet: RelTraitSet,
    inputRel: RelNode,
    outputRowType: RelDataType,
-    inputRowType: RelDataType,
-    grouping: Array[Int],
-    auxGrouping: Array[Int],
+    val grouping: Array[Int],
+    val auxGrouping: Array[Int],
    aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
    val isMerge: Boolean,
    val isFinal: Boolean)
@@ -66,10 +63,6 @@ abstract class BatchExecGroupAggregateBase(
  override def deriveRowType(): RelDataType = outputRowType
-  def getGrouping: Array[Int] = grouping
-  def getAuxGrouping: Array[Int] = auxGrouping
  def getAggCallList: Seq[AggregateCall] = aggCallToAggFunction.map(_._1)
  def getAggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)] = aggCallToAggFunction
......
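(Aside: the class comment above describes two-stage aggregation and why it requires mergeable functions. A planner-independent Scala sketch of the idea for AVG follows: each local task keeps a (sum, count) partial per group before the shuffle, and the global stage merges the partials. The object and method names are illustrative only.)

```scala
// Sketch only: two-stage AVG with a mergeable accumulator (sum, count).
// Illustrates the local/global split, not Flink's operator code.
object TwoPhaseAvgSketch {
  final case class Acc(sum: Double, count: Long) {
    def add(v: Double): Acc = Acc(sum + v, count + 1)
    def merge(other: Acc): Acc = Acc(sum + other.sum, count + other.count)
    def result: Double = if (count == 0) 0.0 else sum / count
  }

  // Stage 1: pre-aggregate the rows seen by one subtask, keyed by group.
  def localAgg(rows: Seq[(String, Double)]): Map[String, Acc] =
    rows.groupBy(_._1).map { case (key, vs) =>
      key -> vs.foldLeft(Acc(0.0, 0L))((acc, row) => acc.add(row._2))
    }

  // Stage 2: after the shuffle, merge all partials for a group and finish.
  def globalAgg(partials: Seq[Map[String, Acc]]): Map[String, Double] =
    partials.flatten.groupBy(_._1).map { case (key, entries) =>
      key -> entries.map(_._2).reduce((a, b) => a.merge(b)).result
    }

  def main(args: Array[String]): Unit = {
    // two "subtasks" each pre-aggregate their local rows before the shuffle
    val partial1 = localAgg(Seq(("a", 1.0), ("a", 3.0), ("b", 5.0)))
    val partial2 = localAgg(Seq(("a", 2.0), ("b", 7.0)))
    println(globalAgg(Seq(partial1, partial2))) // Map(a -> 2.0, b -> 6.0)
  }
}
```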
@@ -60,7 +60,7 @@ class BatchExecHashAggRule
    operand(classOf[FlinkLogicalAggregate],
      operand(classOf[RelNode], any)),
    "BatchExecHashAggRule")
-  with BatchExecAggRuleBase {
+  with BatchPhysicalAggRuleBase {
  override def matches(call: RelOptRuleCall): Boolean = {
    val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig
@@ -100,7 +100,6 @@ class BatchExecHashAggRule
      val providedTraitSet = localRequiredTraitSet
      val localHashAgg = createLocalAgg(
        agg.getCluster,
-        call.builder(),
        providedTraitSet,
        newInput,
        agg.getRowType,
@@ -136,7 +135,6 @@ class BatchExecHashAggRule
      val newLocalHashAgg = RelOptRule.convert(localHashAgg, requiredTraitSet)
      val globalHashAgg = new BatchExecHashAggregate(
        agg.getCluster,
-        call.builder(),
        aggProvidedTraitSet,
        newLocalHashAgg,
        agg.getRowType,
@@ -167,7 +165,6 @@ class BatchExecHashAggRule
      val newInput = RelOptRule.convert(input, requiredTraitSet)
      val hashAgg = new BatchExecHashAggregate(
        agg.getCluster,
-        call.builder(),
        aggProvidedTraitSet,
        newInput,
        agg.getRowType,
......
@@ -85,7 +85,7 @@ class BatchExecHashJoinRule
    val distinctKeys = 0 until join.getRight.getRowType.getFieldCount
    val useBuildDistinct = chooseSemiBuildDistinct(join.getRight, distinctKeys)
    if (useBuildDistinct) {
-      (addLocalDistinctAgg(join.getRight, distinctKeys, call.builder()), true)
+      (addLocalDistinctAgg(join.getRight, distinctKeys), true)
    } else {
      (join.getRight, false)
    }
......
@@ -27,7 +27,6 @@ import org.apache.flink.table.planner.plan.utils.{FlinkRelMdUtil, FlinkRelOptUti
import org.apache.calcite.plan.RelOptRule
import org.apache.calcite.rel.RelNode
-import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.util.ImmutableBitSet
import java.lang.{Boolean => JBoolean, Double => JDouble}
@@ -36,15 +35,13 @@ trait BatchExecJoinRuleBase {
  def addLocalDistinctAgg(
      node: RelNode,
-      distinctKeys: Seq[Int],
-      relBuilder: RelBuilder): RelNode = {
+      distinctKeys: Seq[Int]): RelNode = {
    val localRequiredTraitSet = node.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)
    val newInput = RelOptRule.convert(node, localRequiredTraitSet)
    val providedTraitSet = localRequiredTraitSet
    new BatchExecLocalHashAggregate(
      node.getCluster,
-      relBuilder,
      providedTraitSet,
      newInput,
      node.getRowType, // output row type
......
@@ -54,7 +54,7 @@ class BatchExecNestedLoopJoinRule
    val distinctKeys = 0 until join.getRight.getRowType.getFieldCount
    val useBuildDistinct = chooseSemiBuildDistinct(join.getRight, distinctKeys)
    if (useBuildDistinct) {
-      addLocalDistinctAgg(join.getRight, distinctKeys, call.builder())
+      addLocalDistinctAgg(join.getRight, distinctKeys)
    } else {
      join.getRight
    }
......
...@@ -34,35 +34,35 @@ import org.apache.calcite.rel._ ...@@ -34,35 +34,35 @@ import org.apache.calcite.rel._
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
/** /**
* Rule that converts [[FlinkLogicalAggregate]] to * Rule that converts [[FlinkLogicalAggregate]] to
* {{{ * {{{
* BatchExecSortAggregate (global) * BatchExecSortAggregate (global)
* +- Sort (exists if group keys are not empty) * +- Sort (exists if group keys are not empty)
* +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
* +- BatchExecLocalSortAggregate (local) * +- BatchExecLocalSortAggregate (local)
* +- Sort (exists if group keys are not empty) * +- Sort (exists if group keys are not empty)
* +- input of agg * +- input of agg
* }}} * }}}
* when all aggregate functions are mergeable * when all aggregate functions are mergeable
* and [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is TWO_PHASE, or * and [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is TWO_PHASE, or
* {{{ * {{{
* BatchExecSortAggregate * BatchExecSortAggregate
* +- Sort (exists if group keys are not empty) * +- Sort (exists if group keys are not empty)
* +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
* +- input of agg * +- input of agg
* }}} * }}}
* when some aggregate functions are not mergeable * when some aggregate functions are not mergeable
* or [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is ONE_PHASE. * or [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is ONE_PHASE.
* *
* Note: if [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is NONE, * Note: if [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is NONE,
* this rule will try to create the two possibilities above and choose the best one based on cost. * this rule will try to create the two possibilities above and choose the best one based on cost.
*/ */
class BatchExecSortAggRule class BatchExecSortAggRule
extends RelOptRule( extends RelOptRule(
operand(classOf[FlinkLogicalAggregate], operand(classOf[FlinkLogicalAggregate],
operand(classOf[RelNode], any)), operand(classOf[RelNode], any)),
"BatchExecSortAggRule") "BatchExecSortAggRule")
with BatchExecAggRuleBase { with BatchPhysicalAggRuleBase {
override def matches(call: RelOptRuleCall): Boolean = { override def matches(call: RelOptRuleCall): Boolean = {
val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig
...@@ -99,7 +99,6 @@ class BatchExecSortAggRule ...@@ -99,7 +99,6 @@ class BatchExecSortAggRule
val localSortAgg = createLocalAgg( val localSortAgg = createLocalAgg(
agg.getCluster, agg.getCluster,
call.builder(),
providedLocalTraitSet, providedLocalTraitSet,
newLocalInput, newLocalInput,
agg.getRowType, agg.getRowType,
...@@ -142,7 +141,6 @@ class BatchExecSortAggRule ...@@ -142,7 +141,6 @@ class BatchExecSortAggRule
val newInputForFinalAgg = RelOptRule.convert(localSortAgg, requiredTraitSet) val newInputForFinalAgg = RelOptRule.convert(localSortAgg, requiredTraitSet)
val globalSortAgg = new BatchExecSortAggregate( val globalSortAgg = new BatchExecSortAggregate(
agg.getCluster, agg.getCluster,
call.builder(),
aggProvidedTraitSet, aggProvidedTraitSet,
newInputForFinalAgg, newInputForFinalAgg,
agg.getRowType, agg.getRowType,
...@@ -177,7 +175,6 @@ class BatchExecSortAggRule ...@@ -177,7 +175,6 @@ class BatchExecSortAggRule
val newInput = RelOptRule.convert(input, requiredTraitSet) val newInput = RelOptRule.convert(input, requiredTraitSet)
val sortAgg = new BatchExecSortAggregate( val sortAgg = new BatchExecSortAggregate(
agg.getCluster, agg.getCluster,
call.builder(),
aggProvidedTraitSet, aggProvidedTraitSet,
newInput, newInput,
agg.getRowType, agg.getRowType,
......
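The sort-agg rule above picks between the two-phase and one-phase plans based on the agg phase strategy and on whether all aggregate functions are mergeable. A minimal sketch of steering that choice from user code, assuming a `TableEnvironment` named `tEnv` and a table `T` (both illustrative; `table.optimizer.agg-phase-strategy` is the option key behind `TABLE_OPTIMIZER_AGG_PHASE_STRATEGY`):

```scala
// Force the two-phase (local + global) aggregate shape documented above;
// use "ONE_PHASE" to force the single global aggregate instead.
tEnv.getConfig.getConfiguration
  .setString("table.optimizer.agg-phase-strategy", "TWO_PHASE")
tEnv.executeSql("SELECT a, SUM(b) FROM T GROUP BY a").print()
```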
...@@ -46,34 +46,34 @@ import org.apache.commons.math3.util.ArithmeticUtils ...@@ -46,34 +46,34 @@ import org.apache.commons.math3.util.ArithmeticUtils
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
/** /**
* Rule to convert a [[FlinkLogicalWindowAggregate]] into a * Rule to convert a [[FlinkLogicalWindowAggregate]] into a
* {{{ * {{{
* BatchExecHash(or Sort)WindowAggregate (global) * BatchExecHash(or Sort)WindowAggregate (global)
* +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
* +- BatchExecLocalHash(or Sort)WindowAggregate (local) * +- BatchExecLocalHash(or Sort)WindowAggregate (local)
* +- input of window agg * +- input of window agg
* }}} * }}}
* when all aggregate functions are mergeable * when all aggregate functions are mergeable
* and [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is TWO_PHASE, or * and [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is TWO_PHASE, or
* {{{ * {{{
* BatchExecHash(or Sort)WindowAggregate * BatchExecHash(or Sort)WindowAggregate
* +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton)
* +- input of window agg * +- input of window agg
* }}} * }}}
* when some aggregate functions are not mergeable * when some aggregate functions are not mergeable
* or [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is ONE_PHASE. * or [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is ONE_PHASE.
* *
* Note: if [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is NONE, * Note: if [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is NONE,
* this rule will try to create the two possibilities above and choose the best one based on cost. * this rule will try to create the two possibilities above and choose the best one based on cost.
* If all aggregate function buffers are fixed-length, the rule will choose the hash window agg. * If all aggregate function buffers are fixed-length, the rule will choose the hash window agg.
*/ */
class BatchExecWindowAggregateRule class BatchExecWindowAggregateRule
extends RelOptRule( extends RelOptRule(
operand(classOf[FlinkLogicalWindowAggregate], operand(classOf[FlinkLogicalWindowAggregate],
operand(classOf[RelNode], any)), operand(classOf[RelNode], any)),
FlinkRelFactories.LOGICAL_BUILDER_WITHOUT_AGG_INPUT_PRUNE, FlinkRelFactories.LOGICAL_BUILDER_WITHOUT_AGG_INPUT_PRUNE,
"BatchExecWindowAggregateRule") "BatchExecWindowAggregateRule")
with BatchExecAggRuleBase { with BatchPhysicalAggRuleBase {
override def matches(call: RelOptRuleCall): Boolean = { override def matches(call: RelOptRuleCall): Boolean = {
val agg: FlinkLogicalWindowAggregate = call.rel(0) val agg: FlinkLogicalWindowAggregate = call.rel(0)
...@@ -346,11 +346,11 @@ class BatchExecWindowAggregateRule ...@@ -346,11 +346,11 @@ class BatchExecWindowAggregateRule
} }
/** /**
* Returns true for a sliding window with slideSize < windowSize && gcd(windowSize, slideSize) > 1. * Returns true for a sliding window with slideSize < windowSize && gcd(windowSize, slideSize) > 1.
* Otherwise returns false, including the cases of a tumbling window, * Otherwise returns false, including the cases of a tumbling window,
* a sliding window with slideSize >= windowSize, and * a sliding window with slideSize >= windowSize, and
* a sliding window with slideSize < windowSize but gcd(windowSize, slideSize) == 1. * a sliding window with slideSize < windowSize but gcd(windowSize, slideSize) == 1.
*/ */
private def useAssignPane( private def useAssignPane(
aggregateList: Array[UserDefinedFunction], aggregateList: Array[UserDefinedFunction],
windowSize: Long, windowSize: Long,
...@@ -360,12 +360,12 @@ class BatchExecWindowAggregateRule ...@@ -360,12 +360,12 @@ class BatchExecWindowAggregateRule
} }
/** /**
* For a sliding window that cannot use the pane optimization, i.e. * For a sliding window that cannot use the pane optimization, i.e.
* slideSize < windowSize && ArithmeticUtils.gcd(windowSize, slideSize) == 1, the local * slideSize < windowSize && ArithmeticUtils.gcd(windowSize, slideSize) == 1, the local
* aggregate is disabled. * aggregate is disabled.
* Otherwise, the same decision procedure as for the group aggregate is used to decide whether * Otherwise, the same decision procedure as for the group aggregate is used to decide whether
* to use a local aggregate or not. * to use a local aggregate or not.
*/ */
private def supportLocalWindowAgg( private def supportLocalWindowAgg(
call: RelOptRuleCall, call: RelOptRuleCall,
tableConfig: TableConfig, tableConfig: TableConfig,
......
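A small illustration of the pane condition used by `useAssignPane` and `supportLocalWindowAgg` above (sketch only, not part of this commit; units are abstracted to minutes for readability, while the planner works with millisecond window sizes):

```scala
import org.apache.commons.math3.util.ArithmeticUtils

// A 15-minute window sliding every 10 minutes: slideSize < windowSize and
// gcd(windowSize, slideSize) = 5 > 1, so the window can be split into 5-minute panes
// that are pre-aggregated once and shared by the overlapping windows.
val windowSize = 15L
val slideSize = 10L
val paneSize = ArithmeticUtils.gcd(windowSize, slideSize) // 5
val usePanes = slideSize < windowSize && paneSize > 1     // true
```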
...@@ -24,7 +24,7 @@ import org.apache.flink.table.planner.JArrayList ...@@ -24,7 +24,7 @@ import org.apache.flink.table.planner.JArrayList
import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._ import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecLocalHashAggregate, BatchExecLocalSortAggregate} import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalHashAggregate, BatchExecLocalSortAggregate, BatchPhysicalGroupAggregateBase}
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil} import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil}
import org.apache.flink.table.planner.utils.AggregatePhaseStrategy import org.apache.flink.table.planner.utils.AggregatePhaseStrategy
import org.apache.flink.table.planner.utils.TableConfigUtils.getAggPhaseStrategy import org.apache.flink.table.planner.utils.TableConfigUtils.getAggPhaseStrategy
...@@ -36,12 +36,11 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} ...@@ -36,12 +36,11 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{Aggregate, AggregateCall} import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rel.{RelCollation, RelCollations, RelFieldCollation, RelNode} import org.apache.calcite.rel.{RelCollation, RelCollations, RelFieldCollation, RelNode}
import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.util.Util import org.apache.calcite.util.Util
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
trait BatchExecAggRuleBase { trait BatchPhysicalAggRuleBase {
protected def inferLocalAggType( protected def inferLocalAggType(
inputRowType: RelDataType, inputRowType: RelDataType,
...@@ -185,7 +184,6 @@ trait BatchExecAggRuleBase { ...@@ -185,7 +184,6 @@ trait BatchExecAggRuleBase {
protected def createLocalAgg( protected def createLocalAgg(
cluster: RelOptCluster, cluster: RelOptCluster,
relBuilder: RelBuilder,
traitSet: RelTraitSet, traitSet: RelTraitSet,
input: RelNode, input: RelNode,
originalAggRowType: RelDataType, originalAggRowType: RelDataType,
...@@ -193,7 +191,7 @@ trait BatchExecAggRuleBase { ...@@ -193,7 +191,7 @@ trait BatchExecAggRuleBase {
auxGrouping: Array[Int], auxGrouping: Array[Int],
aggBufferTypes: Array[Array[DataType]], aggBufferTypes: Array[Array[DataType]],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)], aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
isLocalHashAgg: Boolean): BatchExecGroupAggregateBase = { isLocalHashAgg: Boolean): BatchPhysicalGroupAggregateBase = {
val inputRowType = input.getRowType val inputRowType = input.getRowType
val aggFunctions = aggCallToAggFunction.map(_._2).toArray val aggFunctions = aggCallToAggFunction.map(_._2).toArray
...@@ -213,7 +211,6 @@ trait BatchExecAggRuleBase { ...@@ -213,7 +211,6 @@ trait BatchExecAggRuleBase {
if (isLocalHashAgg) { if (isLocalHashAgg) {
new BatchExecLocalHashAggregate( new BatchExecLocalHashAggregate(
cluster, cluster,
relBuilder,
traitSet, traitSet,
input, input,
localAggRowType, localAggRowType,
...@@ -224,7 +221,6 @@ trait BatchExecAggRuleBase { ...@@ -224,7 +221,6 @@ trait BatchExecAggRuleBase {
} else { } else {
new BatchExecLocalSortAggregate( new BatchExecLocalSortAggregate(
cluster, cluster,
relBuilder,
traitSet, traitSet,
input, input,
localAggRowType, localAggRowType,
......
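After the rename, `createLocalAgg` is typed against `BatchPhysicalGroupAggregateBase`, so callers can handle both local aggregate flavors uniformly. A hedged sketch (the helper below is illustrative and not part of the commit; the two concrete classes are the ones constructed by `createLocalAgg` above):

```scala
// Distinguish the two local aggregate implementations behind the common base class.
def describeLocalAgg(agg: BatchPhysicalGroupAggregateBase): String = agg match {
  case _: BatchExecLocalHashAggregate => "local hash aggregate"
  case _: BatchExecLocalSortAggregate => "local sort aggregate"
  case other => s"other group aggregate: ${other.getClass.getSimpleName}"
}
```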
...@@ -22,31 +22,30 @@ import org.apache.flink.table.api.TableException ...@@ -22,31 +22,30 @@ import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand} import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand, BatchPhysicalGroupAggregateBase}
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil} import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleOperand} import org.apache.calcite.plan.{RelOptRule, RelOptRuleOperand}
import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.RelNode
import org.apache.calcite.rex.RexUtil import org.apache.calcite.rex.RexUtil
import org.apache.calcite.tools.RelBuilder
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
/** /**
* Planner rule that rewrites a one-phase aggregate into a two-phase aggregate * Planner rule that rewrites a one-phase aggregate into a two-phase aggregate
* when the following conditions are met: * when the following conditions are met:
* 1. there is no local aggregate, * 1. there is no local aggregate,
* 2. the aggregate has non-empty grouping and two phase aggregate strategy is enabled, * 2. the aggregate has non-empty grouping and two phase aggregate strategy is enabled,
* 3. the input is [[BatchPhysicalExpand]] and there is at least one expand row * 3. the input is [[BatchPhysicalExpand]] and there is at least one expand row
* whose grouping columns are all constant. * whose grouping columns are all constant.
*/ */
abstract class EnforceLocalAggRuleBase( abstract class EnforceLocalAggRuleBase(
operand: RelOptRuleOperand, operand: RelOptRuleOperand,
description: String) description: String)
extends RelOptRule(operand, description) extends RelOptRule(operand, description)
with BatchExecAggRuleBase { with BatchPhysicalAggRuleBase {
protected def isTwoPhaseAggEnabled(agg: BatchExecGroupAggregateBase): Boolean = { protected def isTwoPhaseAggEnabled(agg: BatchPhysicalGroupAggregateBase): Boolean = {
val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(agg) val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(agg)
val aggFunctions = agg.getAggCallToAggFunction.map(_._2).toArray val aggFunctions = agg.getAggCallToAggFunction.map(_._2).toArray
isTwoPhaseAggWorkable(aggFunctions, tableConfig) isTwoPhaseAggWorkable(aggFunctions, tableConfig)
...@@ -64,14 +63,13 @@ abstract class EnforceLocalAggRuleBase( ...@@ -64,14 +63,13 @@ abstract class EnforceLocalAggRuleBase(
} }
protected def createLocalAgg( protected def createLocalAgg(
completeAgg: BatchExecGroupAggregateBase, completeAgg: BatchPhysicalGroupAggregateBase,
input: RelNode, input: RelNode): BatchPhysicalGroupAggregateBase = {
relBuilder: RelBuilder): BatchExecGroupAggregateBase = {
val cluster = completeAgg.getCluster val cluster = completeAgg.getCluster
val inputRowType = input.getRowType val inputRowType = input.getRowType
val grouping = completeAgg.getGrouping val grouping = completeAgg.grouping
val auxGrouping = completeAgg.getAuxGrouping val auxGrouping = completeAgg.auxGrouping
val aggCalls = completeAgg.getAggCallList val aggCalls = completeAgg.getAggCallList
val aggCallToAggFunction = completeAgg.getAggCallToAggFunction val aggCallToAggFunction = completeAgg.getAggCallToAggFunction
...@@ -91,7 +89,6 @@ abstract class EnforceLocalAggRuleBase( ...@@ -91,7 +89,6 @@ abstract class EnforceLocalAggRuleBase(
createLocalAgg( createLocalAgg(
cluster, cluster,
relBuilder,
traitSet, traitSet,
input, input,
completeAgg.getRowType, completeAgg.getRowType,
...@@ -104,10 +101,10 @@ abstract class EnforceLocalAggRuleBase( ...@@ -104,10 +101,10 @@ abstract class EnforceLocalAggRuleBase(
} }
protected def createExchange( protected def createExchange(
completeAgg: BatchExecGroupAggregateBase, completeAgg: BatchPhysicalGroupAggregateBase,
input: RelNode): BatchPhysicalExchange = { input: RelNode): BatchPhysicalExchange = {
val cluster = completeAgg.getCluster val cluster = completeAgg.getCluster
val grouping = completeAgg.getGrouping val grouping = completeAgg.grouping
// local aggregate outputs group fields first, and then agg calls // local aggregate outputs group fields first, and then agg calls
val distributionFields = grouping.indices.map(Integer.valueOf) val distributionFields = grouping.indices.map(Integer.valueOf)
...@@ -121,11 +118,10 @@ abstract class EnforceLocalAggRuleBase( ...@@ -121,11 +118,10 @@ abstract class EnforceLocalAggRuleBase(
} }
protected def createGlobalAgg( protected def createGlobalAgg(
completeAgg: BatchExecGroupAggregateBase, completeAgg: BatchPhysicalGroupAggregateBase,
input: RelNode, input: RelNode): BatchPhysicalGroupAggregateBase = {
relBuilder: RelBuilder): BatchExecGroupAggregateBase = { val grouping = completeAgg.grouping
val grouping = completeAgg.getGrouping val auxGrouping = completeAgg.auxGrouping
val auxGrouping = completeAgg.getAuxGrouping
val aggCallToAggFunction = completeAgg.getAggCallToAggFunction val aggCallToAggFunction = completeAgg.getAggCallToAggFunction
val (newGrouping, newAuxGrouping) = getGlobalAggGroupSetPair(grouping, auxGrouping) val (newGrouping, newAuxGrouping) = getGlobalAggGroupSetPair(grouping, auxGrouping)
...@@ -138,7 +134,6 @@ abstract class EnforceLocalAggRuleBase( ...@@ -138,7 +134,6 @@ abstract class EnforceLocalAggRuleBase(
case _: BatchExecHashAggregate => case _: BatchExecHashAggregate =>
new BatchExecHashAggregate( new BatchExecHashAggregate(
completeAgg.getCluster, completeAgg.getCluster,
relBuilder,
completeAgg.getTraitSet, completeAgg.getTraitSet,
input, input,
aggRowType, aggRowType,
...@@ -151,7 +146,6 @@ abstract class EnforceLocalAggRuleBase( ...@@ -151,7 +146,6 @@ abstract class EnforceLocalAggRuleBase(
case _: BatchExecSortAggregate => case _: BatchExecSortAggregate =>
new BatchExecSortAggregate( new BatchExecSortAggregate(
completeAgg.getCluster, completeAgg.getCluster,
relBuilder,
completeAgg.getTraitSet, completeAgg.getTraitSet,
input, input,
aggRowType, aggRowType,
......
...@@ -18,36 +18,36 @@ ...@@ -18,36 +18,36 @@
package org.apache.flink.table.planner.plan.rules.physical.batch package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalExchange, BatchPhysicalExpand, BatchExecHashAggregate} import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecHashAggregate, BatchPhysicalExchange, BatchPhysicalExpand}
import org.apache.calcite.plan.RelOptRule.{any, operand} import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.RelOptRuleCall import org.apache.calcite.plan.RelOptRuleCall
/** /**
* An [[EnforceLocalAggRuleBase]] that matches [[BatchExecHashAggregate]] * An [[EnforceLocalAggRuleBase]] that matches [[BatchExecHashAggregate]]
* *
* for example: select count(*) from t group by rollup (a, b) * for example: select count(*) from t group by rollup (a, b)
* The physical plan * The physical plan
* *
* {{{ * {{{
* HashAggregate(isMerge=[false], groupBy=[a, b, $e], select=[a, b, $e, COUNT(*)]) * HashAggregate(isMerge=[false], groupBy=[a, b, $e], select=[a, b, $e, COUNT(*)])
* +- Exchange(distribution=[hash[a, b, $e]]) * +- Exchange(distribution=[hash[a, b, $e]])
* +- Expand(projects=[{a=[$0], b=[$1], $e=[0]}, * +- Expand(projects=[{a=[$0], b=[$1], $e=[0]},
* {a=[$0], b=[null], $e=[1]}, * {a=[$0], b=[null], $e=[1]},
* {a=[null], b=[null], $e=[3]}]) * {a=[null], b=[null], $e=[3]}])
* }}} * }}}
* *
* will be rewritten to * will be rewritten to
* *
* {{{ * {{{
* HashAggregate(isMerge=[true], groupBy=[a, b, $e], select=[a, b, $e, Final_COUNT(count1$0)]) * HashAggregate(isMerge=[true], groupBy=[a, b, $e], select=[a, b, $e, Final_COUNT(count1$0)])
* +- Exchange(distribution=[hash[a, b, $e]]) * +- Exchange(distribution=[hash[a, b, $e]])
* +- LocalHashAggregate(groupBy=[a, b, $e], select=[a, b, $e, Partial_COUNT(*) AS count1$0] * +- LocalHashAggregate(groupBy=[a, b, $e], select=[a, b, $e, Partial_COUNT(*) AS count1$0]
* +- Expand(projects=[{a=[$0], b=[$1], $e=[0]}, * +- Expand(projects=[{a=[$0], b=[$1], $e=[0]},
* {a=[$0], b=[null], $e=[1]}, * {a=[$0], b=[null], $e=[1]},
* {a=[null], b=[null], $e=[3]}]) * {a=[null], b=[null], $e=[3]}])
* }}} * }}}
*/ */
class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase( class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase(
operand(classOf[BatchExecHashAggregate], operand(classOf[BatchExecHashAggregate],
operand(classOf[BatchPhysicalExchange], operand(classOf[BatchPhysicalExchange],
...@@ -60,7 +60,7 @@ class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase( ...@@ -60,7 +60,7 @@ class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase(
val enableTwoPhaseAgg = isTwoPhaseAggEnabled(agg) val enableTwoPhaseAgg = isTwoPhaseAggEnabled(agg)
val grouping = agg.getGrouping val grouping = agg.grouping
val constantShuffleKey = hasConstantShuffleKey(grouping, expand) val constantShuffleKey = hasConstantShuffleKey(grouping, expand)
grouping.nonEmpty && enableTwoPhaseAgg && constantShuffleKey grouping.nonEmpty && enableTwoPhaseAgg && constantShuffleKey
...@@ -70,9 +70,9 @@ class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase( ...@@ -70,9 +70,9 @@ class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase(
val agg: BatchExecHashAggregate = call.rel(0) val agg: BatchExecHashAggregate = call.rel(0)
val expand: BatchPhysicalExpand = call.rel(2) val expand: BatchPhysicalExpand = call.rel(2)
val localAgg = createLocalAgg(agg, expand, call.builder) val localAgg = createLocalAgg(agg, expand)
val exchange = createExchange(agg, localAgg) val exchange = createExchange(agg, localAgg)
val globalAgg = createGlobalAgg(agg, exchange, call.builder) val globalAgg = createGlobalAgg(agg, exchange)
call.transformTo(globalAgg) call.transformTo(globalAgg)
} }
......
...@@ -66,7 +66,7 @@ class EnforceLocalSortAggRule extends EnforceLocalAggRuleBase( ...@@ -66,7 +66,7 @@ class EnforceLocalSortAggRule extends EnforceLocalAggRuleBase(
val enableTwoPhaseAgg = isTwoPhaseAggEnabled(agg) val enableTwoPhaseAgg = isTwoPhaseAggEnabled(agg)
val grouping = agg.getGrouping val grouping = agg.grouping
val constantShuffleKey = hasConstantShuffleKey(grouping, expand) val constantShuffleKey = hasConstantShuffleKey(grouping, expand)
grouping.nonEmpty && enableTwoPhaseAgg && constantShuffleKey grouping.nonEmpty && enableTwoPhaseAgg && constantShuffleKey
...@@ -76,17 +76,17 @@ class EnforceLocalSortAggRule extends EnforceLocalAggRuleBase( ...@@ -76,17 +76,17 @@ class EnforceLocalSortAggRule extends EnforceLocalAggRuleBase(
val agg: BatchExecSortAggregate = call.rel(0) val agg: BatchExecSortAggregate = call.rel(0)
val expand: BatchPhysicalExpand = call.rel(3) val expand: BatchPhysicalExpand = call.rel(3)
val localGrouping = agg.getGrouping val localGrouping = agg.grouping
// create local sort // create local sort
val localSort = createSort(expand, localGrouping) val localSort = createSort(expand, localGrouping)
val localAgg = createLocalAgg(agg, localSort, call.builder) val localAgg = createLocalAgg(agg, localSort)
val exchange = createExchange(agg, localAgg) val exchange = createExchange(agg, localAgg)
// create global sort // create global sort
val globalGrouping = localGrouping.indices.toArray val globalGrouping = localGrouping.indices.toArray
val globalSort = createSort(exchange, globalGrouping) val globalSort = createSort(exchange, globalGrouping)
val globalAgg = createGlobalAgg(agg, globalSort, call.builder) val globalAgg = createGlobalAgg(agg, globalSort)
call.transformTo(globalAgg) call.transformTo(globalAgg)
} }
......
...@@ -26,9 +26,9 @@ import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} ...@@ -26,9 +26,9 @@ import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.RelNode
/** /**
* There may exist a subtree like localHashAggregate -> globalHashAggregate in which the * There may exist a subtree like localHashAggregate -> globalHashAggregate in which the
* intermediate shuffle has been removed. This rule removes the redundant localHashAggregate node. * intermediate shuffle has been removed. This rule removes the redundant localHashAggregate node.
*/ */
class RemoveRedundantLocalHashAggRule extends RelOptRule( class RemoveRedundantLocalHashAggRule extends RelOptRule(
operand(classOf[BatchExecHashAggregate], operand(classOf[BatchExecHashAggregate],
operand(classOf[BatchExecLocalHashAggregate], operand(classOf[BatchExecLocalHashAggregate],
...@@ -36,19 +36,18 @@ class RemoveRedundantLocalHashAggRule extends RelOptRule( ...@@ -36,19 +36,18 @@ class RemoveRedundantLocalHashAggRule extends RelOptRule(
"RemoveRedundantLocalHashAggRule") { "RemoveRedundantLocalHashAggRule") {
override def onMatch(call: RelOptRuleCall): Unit = { override def onMatch(call: RelOptRuleCall): Unit = {
val globalAgg = call.rels(0).asInstanceOf[BatchExecHashAggregate] val globalAgg: BatchExecHashAggregate = call.rel(0)
val localAgg = call.rels(1).asInstanceOf[BatchExecLocalHashAggregate] val localAgg: BatchExecLocalHashAggregate = call.rel(1)
val inputOfLocalAgg = localAgg.getInput val inputOfLocalAgg = localAgg.getInput
val newGlobalAgg = new BatchExecHashAggregate( val newGlobalAgg = new BatchExecHashAggregate(
globalAgg.getCluster, globalAgg.getCluster,
call.builder(),
globalAgg.getTraitSet, globalAgg.getTraitSet,
inputOfLocalAgg, inputOfLocalAgg,
globalAgg.getRowType, globalAgg.getRowType,
inputOfLocalAgg.getRowType, inputOfLocalAgg.getRowType,
inputOfLocalAgg.getRowType, inputOfLocalAgg.getRowType,
localAgg.getGrouping, localAgg.grouping,
localAgg.getAuxGrouping, localAgg.auxGrouping,
// Use the localAgg agg calls because the global agg call filters were removed, // Use the localAgg agg calls because the global agg call filters were removed,
// see BatchExecHashAggRule for details. // see BatchExecHashAggRule for details.
localAgg.getAggCallToAggFunction, localAgg.getAggCallToAggFunction,
......
...@@ -40,14 +40,13 @@ abstract class RemoveRedundantLocalSortAggRule( ...@@ -40,14 +40,13 @@ abstract class RemoveRedundantLocalSortAggRule(
val inputOfLocalAgg = getOriginalInputOfLocalAgg(call) val inputOfLocalAgg = getOriginalInputOfLocalAgg(call)
val newGlobalAgg = new BatchExecSortAggregate( val newGlobalAgg = new BatchExecSortAggregate(
globalAgg.getCluster, globalAgg.getCluster,
call.builder(),
globalAgg.getTraitSet, globalAgg.getTraitSet,
inputOfLocalAgg, inputOfLocalAgg,
globalAgg.getRowType, globalAgg.getRowType,
inputOfLocalAgg.getRowType, inputOfLocalAgg.getRowType,
inputOfLocalAgg.getRowType, inputOfLocalAgg.getRowType,
localAgg.getGrouping, localAgg.grouping,
localAgg.getAuxGrouping, localAgg.auxGrouping,
// Use the localAgg agg calls because the global agg call filters were removed, // Use the localAgg agg calls because the global agg call filters were removed,
// see BatchExecSortAggRule for details. // see BatchExecSortAggRule for details.
localAgg.getAggCallToAggFunction, localAgg.getAggCallToAggFunction,
......
...@@ -88,17 +88,17 @@ class FlinkRelMdHandlerTestBase { ...@@ -88,17 +88,17 @@ class FlinkRelMdHandlerTestBase {
// TODO batch RelNode and stream RelNode should have different PlannerContext // TODO batch RelNode and stream RelNode should have different PlannerContext
// and RelOptCluster because they have different trait definitions. // and RelOptCluster because they have different trait definitions.
val plannerContext: PlannerContext = val plannerContext: PlannerContext =
new PlannerContext( new PlannerContext(
tableConfig, tableConfig,
new FunctionCatalog(tableConfig, catalogManager, moduleManager), new FunctionCatalog(tableConfig, catalogManager, moduleManager),
catalogManager, catalogManager,
CalciteSchema.from(rootSchema), CalciteSchema.from(rootSchema),
util.Arrays.asList( util.Arrays.asList(
ConventionTraitDef.INSTANCE, ConventionTraitDef.INSTANCE,
FlinkRelDistributionTraitDef.INSTANCE, FlinkRelDistributionTraitDef.INSTANCE,
RelCollationTraitDef.INSTANCE RelCollationTraitDef.INSTANCE
)
) )
)
val typeFactory: FlinkTypeFactory = plannerContext.getTypeFactory val typeFactory: FlinkTypeFactory = plannerContext.getTypeFactory
val mq: FlinkRelMetadataQuery = FlinkRelMetadataQuery.instance() val mq: FlinkRelMetadataQuery = FlinkRelMetadataQuery.instance()
...@@ -981,7 +981,6 @@ class FlinkRelMdHandlerTestBase { ...@@ -981,7 +981,6 @@ class FlinkRelMdHandlerTestBase {
val batchLocalAgg = new BatchExecLocalHashAggregate( val batchLocalAgg = new BatchExecLocalHashAggregate(
cluster, cluster,
relBuilder,
batchPhysicalTraits, batchPhysicalTraits,
studentBatchScan, studentBatchScan,
rowTypeOfLocalAgg, rowTypeOfLocalAgg,
...@@ -994,7 +993,6 @@ class FlinkRelMdHandlerTestBase { ...@@ -994,7 +993,6 @@ class FlinkRelMdHandlerTestBase {
cluster, batchLocalAgg.getTraitSet.replace(hash0), batchLocalAgg, hash0) cluster, batchLocalAgg.getTraitSet.replace(hash0), batchLocalAgg, hash0)
val batchGlobalAgg = new BatchExecHashAggregate( val batchGlobalAgg = new BatchExecHashAggregate(
cluster, cluster,
relBuilder,
batchPhysicalTraits, batchPhysicalTraits,
batchExchange1, batchExchange1,
rowTypeOfGlobalAgg, rowTypeOfGlobalAgg,
...@@ -1009,7 +1007,6 @@ class FlinkRelMdHandlerTestBase { ...@@ -1009,7 +1007,6 @@ class FlinkRelMdHandlerTestBase {
studentBatchScan.getTraitSet.replace(hash3), studentBatchScan, hash3) studentBatchScan.getTraitSet.replace(hash3), studentBatchScan, hash3)
val batchGlobalAggWithoutLocal = new BatchExecHashAggregate( val batchGlobalAggWithoutLocal = new BatchExecHashAggregate(
cluster, cluster,
relBuilder,
batchPhysicalTraits, batchPhysicalTraits,
batchExchange2, batchExchange2,
rowTypeOfGlobalAgg, rowTypeOfGlobalAgg,
...@@ -1111,7 +1108,6 @@ class FlinkRelMdHandlerTestBase { ...@@ -1111,7 +1108,6 @@ class FlinkRelMdHandlerTestBase {
val batchLocalAggWithAuxGroup = new BatchExecLocalHashAggregate( val batchLocalAggWithAuxGroup = new BatchExecLocalHashAggregate(
cluster, cluster,
relBuilder,
batchPhysicalTraits, batchPhysicalTraits,
studentBatchScan, studentBatchScan,
rowTypeOfLocalAgg, rowTypeOfLocalAgg,
...@@ -1133,7 +1129,6 @@ class FlinkRelMdHandlerTestBase { ...@@ -1133,7 +1129,6 @@ class FlinkRelMdHandlerTestBase {
.add("cnt", longType).build() .add("cnt", longType).build()
val batchGlobalAggWithAuxGroup = new BatchExecHashAggregate( val batchGlobalAggWithAuxGroup = new BatchExecHashAggregate(
cluster, cluster,
relBuilder,
batchPhysicalTraits, batchPhysicalTraits,
batchExchange, batchExchange,
rowTypeOfGlobalAgg, rowTypeOfGlobalAgg,
...@@ -1148,7 +1143,6 @@ class FlinkRelMdHandlerTestBase { ...@@ -1148,7 +1143,6 @@ class FlinkRelMdHandlerTestBase {
studentBatchScan.getTraitSet.replace(hash0), studentBatchScan, hash0) studentBatchScan.getTraitSet.replace(hash0), studentBatchScan, hash0)
val batchGlobalAggWithoutLocalWithAuxGroup = new BatchExecHashAggregate( val batchGlobalAggWithoutLocalWithAuxGroup = new BatchExecHashAggregate(
cluster, cluster,
relBuilder,
batchPhysicalTraits, batchPhysicalTraits,
batchExchange2, batchExchange2,
rowTypeOfGlobalAgg, rowTypeOfGlobalAgg,
...@@ -2416,49 +2410,6 @@ class FlinkRelMdHandlerTestBase { ...@@ -2416,49 +2410,6 @@ class FlinkRelMdHandlerTestBase {
.scan("MyTable2") .scan("MyTable2")
.minus(false).build() .minus(false).build()
private def createGlobalAgg(
table: String, groupBy: String, sum: String): BatchExecHashAggregate = {
val scan: BatchPhysicalBoundedStreamScan =
createDataStreamScan(ImmutableList.of(table), batchPhysicalTraits)
relBuilder.push(scan)
val groupByField = relBuilder.field(groupBy)
val sumField = relBuilder.field(sum)
val hash = FlinkRelDistribution.hash(Array(groupByField.getIndex), requireStrict = true)
val exchange = new BatchPhysicalExchange(cluster, batchPhysicalTraits.replace(hash), scan, hash)
relBuilder.push(exchange)
val logicalAgg = relBuilder.aggregate(
relBuilder.groupKey(groupBy),
relBuilder.aggregateCall(SqlStdOperatorTable.SUM, relBuilder.field(sum))
).build().asInstanceOf[LogicalAggregate]
val aggCalls = logicalAgg.getAggCallList
val aggFunctionFactory = new AggFunctionFactory(
FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
Array.empty[Int],
Array.fill(aggCalls.size())(false))
val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
}
val rowTypeOfGlobalAgg = typeFactory.builder
.add(groupByField.getName, groupByField.getType)
.add(sumField.getName, sumField.getType).build()
new BatchExecHashAggregate(
cluster,
relBuilder,
batchPhysicalTraits,
exchange,
rowTypeOfGlobalAgg,
exchange.getRowType,
exchange.getRowType,
Array(groupByField.getIndex),
auxGrouping = Array(),
aggCallToAggFunction,
isMerge = false)
}
protected def createDataStreamScan[T]( protected def createDataStreamScan[T](
tableNames: util.List[String], traitSet: RelTraitSet): T = { tableNames: util.List[String], traitSet: RelTraitSet): T = {
val table = relBuilder val table = relBuilder
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
package org.apache.flink.table.planner.plan.metadata package org.apache.flink.table.planner.plan.metadata
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchPhysicalCorrelate} import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalCorrelate, BatchPhysicalGroupAggregateBase}
import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.{Aggregate, Correlate} import org.apache.calcite.rel.core.{Aggregate, Correlate}
...@@ -42,11 +42,11 @@ import scala.collection.mutable ...@@ -42,11 +42,11 @@ import scala.collection.mutable
* for Aggregate and Correlate. * for Aggregate and Correlate.
* This test ensures two points. * This test ensures two points.
* 1. all subclasses of [[MetadataHandler]] have explicit metadata estimation * 1. all subclasses of [[MetadataHandler]] have explicit metadata estimation
* for [[Aggregate]] and [[BatchExecGroupAggregateBase]] or have no metadata estimation for * for [[Aggregate]] and [[BatchPhysicalGroupAggregateBase]] or have no metadata estimation for
* [[Aggregate]] and [[BatchExecGroupAggregateBase]] either. * [[Aggregate]] and [[BatchPhysicalGroupAggregateBase]] either.
* 2. all subclasses of [[MetadataHandler]] have explicit metadata estimation * 2. all subclasses of [[MetadataHandler]] have explicit metadata estimation
* for [[Correlate]] and [[BatchExecGroupAggregateBase]] or have no metadata estimation for * for [[Correlate]] and [[BatchPhysicalGroupAggregateBase]] or have no metadata estimation for
* [[Correlate]] and [[BatchExecGroupAggregateBase]] either. * [[Correlate]] and [[BatchPhysicalGroupAggregateBase]] either.
* Be cautious that if logical Aggregate and physical Aggregate or logical Correlate and physical * Be cautious that if logical Aggregate and physical Aggregate or logical Correlate and physical
* Correlate both are present in a MetadataHandler class, their metadata estimation should be same. * Correlate both are present in a MetadataHandler class, their metadata estimation should be same.
* This test does not check this point because every MetadataHandler could have different * This test does not check this point because every MetadataHandler could have different
...@@ -144,7 +144,7 @@ object MetadataHandlerConsistencyTest { ...@@ -144,7 +144,7 @@ object MetadataHandlerConsistencyTest {
@Parameterized.Parameters(name = "logicalNodeClass={0}, physicalNodeClass={1}") @Parameterized.Parameters(name = "logicalNodeClass={0}, physicalNodeClass={1}")
def parameters(): util.Collection[Array[Any]] = { def parameters(): util.Collection[Array[Any]] = {
Seq[Array[Any]]( Seq[Array[Any]](
Array(classOf[Aggregate], classOf[BatchExecGroupAggregateBase]), Array(classOf[Aggregate], classOf[BatchPhysicalGroupAggregateBase]),
Array(classOf[Correlate], classOf[BatchPhysicalCorrelate])) Array(classOf[Correlate], classOf[BatchPhysicalCorrelate]))
} }
} }
...@@ -31,8 +31,8 @@ import org.junit.Before ...@@ -31,8 +31,8 @@ import org.junit.Before
/** /**
* Test for [[EnforceLocalHashAggRule]]. * Test for [[EnforceLocalHashAggRule]].
*/ */
class EnforceLocalHashAggRuleTest extends EnforceLocalAggRuleTestBase { class EnforceLocalHashAggRuleTest extends EnforceLocalAggRuleTestBase {
@Before @Before
...@@ -60,10 +60,10 @@ class EnforceLocalHashAggRuleTest extends EnforceLocalAggRuleTestBase { ...@@ -60,10 +60,10 @@ class EnforceLocalHashAggRuleTest extends EnforceLocalAggRuleTestBase {
} }
/** /**
* Planner rule that ignores the [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] * Planner rule that ignores the [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]]
* value and only enables the one-phase aggregate. * value and only enables the one-phase aggregate.
* This rule is only used for tests. * This rule is only used for tests.
*/ */
class BatchExecHashAggRuleForOnePhase extends BatchExecHashAggRule { class BatchExecHashAggRuleForOnePhase extends BatchExecHashAggRule {
override protected def isTwoPhaseAggWorkable( override protected def isTwoPhaseAggWorkable(
aggFunctions: Array[UserDefinedFunction], tableConfig: TableConfig): Boolean = false aggFunctions: Array[UserDefinedFunction], tableConfig: TableConfig): Boolean = false
......