Commit d1953818 authored by godfrey he, committed by Kurt Young

[FLINK-12559][table-planner-blink] Introduce metadata handlers on window aggregate

This closes #8487
Parent aca9c018
......@@ -19,7 +19,7 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.JDouble
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecGroupAggregateBase
import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase}
import org.apache.flink.table.plan.stats._
import org.apache.flink.table.plan.util.AggregateUtil
......@@ -62,6 +62,14 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
(rel.getGroupSet.toArray ++ auxGroupSet, otherAggCalls)
case rel: BatchExecGroupAggregateBase =>
(rel.getGrouping ++ rel.getAuxGrouping, rel.getAggCallList)
case rel: BatchExecLocalHashWindowAggregate =>
val fullGrouping = rel.getGrouping ++ Array(rel.inputTimeFieldIndex) ++ rel.getAuxGrouping
(fullGrouping, rel.getAggCallList)
case rel: BatchExecLocalSortWindowAggregate =>
val fullGrouping = rel.getGrouping ++ Array(rel.inputTimeFieldIndex) ++ rel.getAuxGrouping
(fullGrouping, rel.getAggCallList)
case rel: BatchExecWindowAggregateBase =>
(rel.getGrouping ++ rel.getAuxGrouping, rel.getAggCallList)
case _ => throw new IllegalArgumentException(s"Cannot handle ${agg.getRelTypeName}!")
}
require(outputIdx >= fullGrouping.length)
......
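The local window aggregates emit the assigned window timestamp (assignTs) directly after the grouping keys, so their full grouping is grouping ++ assignTs ++ auxGrouping. A minimal standalone sketch of that index layout, in plain Scala with hypothetical names rather than the planner API:

object FullGroupingSketch {
  def fullGrouping(grouping: Array[Int], inputTimeFieldIndex: Int, auxGrouping: Array[Int]): Array[Int] =
    grouping ++ Array(inputTimeFieldIndex) ++ auxGrouping

  def main(args: Array[String]): Unit = {
    val g = fullGrouping(Array(0, 2), inputTimeFieldIndex = 3, auxGrouping = Array(5))
    println(g.mkString(", "))  // 0, 2, 3, 5
    // any output index >= g.length refers to an aggregate call, not a group key
    println(4 >= g.length)     // true
  }
}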
......@@ -20,7 +20,7 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.metadata.FlinkMetadata.ColumnInterval
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.nodes.physical.stream._
import org.apache.flink.table.plan.schema.FlinkRelOptTable
......@@ -51,12 +51,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
override def getDef: MetadataDef[ColumnInterval] = FlinkMetadata.ColumnInterval.DEF
/**
* Gets interval of the given column in TableScan.
* Gets interval of the given column on TableScan.
*
* @param ts TableScan RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in TableScan
* @return interval of the given column on TableScan
*/
def getColumnInterval(ts: TableScan, mq: RelMetadataQuery, index: Int): ValueInterval = {
val relOptTable = ts.getTable.asInstanceOf[FlinkRelOptTable]
......@@ -79,12 +79,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets interval of the given column in Values.
* Gets interval of the given column on Values.
*
* @param values Values RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in Values
* @return interval of the given column on Values
*/
def getColumnInterval(values: Values, mq: RelMetadataQuery, index: Int): ValueInterval = {
val tuples = values.tuples
......@@ -101,14 +101,14 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets interval of the given column in Project.
* Gets interval of the given column on Project.
*
* Note: Only supports simple RexNode, e.g. RexInputRef.
*
* @param project Project RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in Project
* @return interval of the given column on Project
*/
def getColumnInterval(project: Project, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
......@@ -130,12 +130,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets interval of the given column in Filter.
* Gets interval of the given column on Filter.
*
* @param filter Filter RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in Filter
* @return interval of the given column on Filter
*/
def getColumnInterval(filter: Filter, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
......@@ -148,12 +148,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets interval of the given column in batch Calc.
* Gets interval of the given column on Calc.
*
* @param calc Calc RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in Filter
* @return interval of the given column on Calc
*/
def getColumnInterval(calc: Calc, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
......@@ -249,12 +249,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets interval of the given column in Exchange.
* Gets interval of the given column on Exchange.
*
* @param exchange Exchange RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in Exchange
* @return interval of the given column on Exchange
*/
def getColumnInterval(exchange: Exchange, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
......@@ -262,12 +262,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets intervals of the given column of Sort.
* Gets interval of the given column on Sort.
*
* @param sort Sort to analyze
* @param sort Sort RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in Sort
* @return interval of the given column on Sort
*/
def getColumnInterval(sort: Sort, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
......@@ -275,9 +275,9 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets intervals of the given column of Expand.
* Gets interval of the given column of Expand.
*
* @param expand expand to analyze
* @param expand expand RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column of Expand
......@@ -309,12 +309,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets intervals of the given column of Rank.
* Gets interval of the given column on Rank.
*
* @param rank [[Rank]] instance to analyze
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in batch Rank
* @return interval of the given column on Rank
*/
def getColumnInterval(
rank: Rank,
......@@ -344,101 +344,106 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets intervals of the given column in Aggregates.
* Gets interval of the given column on Aggregates.
*
* @param aggregate Aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in Aggregate
* @return interval of the given column on Aggregate
*/
def getColumnInterval(aggregate: Aggregate, mq: RelMetadataQuery, index: Int): ValueInterval =
estimateColumnIntervalOfAggregate(aggregate, mq, index)
/**
* Gets intervals of the given column in batch Aggregate.
* Gets interval of the given column on batch group aggregate.
*
* @param aggregate Aggregate RelNode
* @param aggregate batch group aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in batch Aggregate
* @return interval of the given column on batch group aggregate
*/
def getColumnInterval(
aggregate: BatchExecGroupAggregateBase,
mq: RelMetadataQuery,
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)
/**
* Gets interval of the given column on stream group aggregate.
*
* @param aggregate stream group aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column on stream group Aggregate
*/
def getColumnInterval(
aggregate: StreamExecGroupAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)
/**
* Gets interval of the given column on stream local group aggregate.
*
* @param aggregate stream local group aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column on stream local group Aggregate
*/
def getColumnInterval(
aggregate: StreamExecLocalGroupAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)
/**
* Gets interval of the given column on stream global group aggregate.
*
* @param aggregate stream global group aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column on stream global group Aggregate
*/
def getColumnInterval(
aggregate: StreamExecGlobalGroupAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)
/**
* Gets intervals of the given column in batch OverWindowAggregate.
* Gets interval of the given column on window aggregate.
*
* @param aggregate Aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in batch OverWindowAggregate
* @param agg window aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column on window Aggregate
*/
def getColumnInterval(
aggregate: BatchExecOverAggregate,
agg: WindowAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = getColumnIntervalOfOverWindow(aggregate, mq, index)
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)
/**
* Gets intervals of the given column in batch OverWindowAggregate.
* Gets interval of the given column on batch window aggregate.
*
* @param aggregate Aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in batch OverWindowAggregate
* @param agg batch window aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column on batch window Aggregate
*/
def getColumnInterval(
aggregate: StreamExecOverAggregate,
agg: BatchExecWindowAggregateBase,
mq: RelMetadataQuery,
index: Int): ValueInterval = getColumnIntervalOfOverWindow(aggregate, mq, index)
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)
/**
* Gets intervals of the given column in calcite window.
* Gets interval of the given column on stream window aggregate.
*
* @param window Window RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in window
* @param agg stream window aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column on stream window Aggregate
*/
def getColumnInterval(
window: Window,
agg: StreamExecGroupWindowAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = {
getColumnIntervalOfOverWindow(window, mq, index)
}
private def getColumnIntervalOfOverWindow(
overWindow: SingleRel,
mq: RelMetadataQuery,
index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
val input = overWindow.getInput
val fieldsCountOfInput = input.getRowType.getFieldCount
if (index < fieldsCountOfInput) {
fmq.getColumnInterval(input, index)
} else {
// cannot estimate the column interval of aggregate function calls.
null
}
}
// TODO supports window aggregate
index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)
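All of these overloads defer to estimateColumnIntervalOfAggregate below, which first decides whether the queried column is a group key (answered from the input column) or an aggregate call (answered per call, where possible). A plain-Scala sketch of that dispatch, with hypothetical names:

object ColumnClassifySketch {
  // Left(i): the column is a group key mapped from input field i.
  // Right(j): the column is the j-th aggregate call of this node.
  def classify(groupSet: Array[Int], index: Int): Either[Int, Int] =
    if (index < groupSet.length) Left(groupSet(index))
    else Right(index - groupSet.length)

  def main(args: Array[String]): Unit = {
    println(classify(Array(1, 3), 0))  // Left(1)
    println(classify(Array(1, 3), 2))  // Right(0)
  }
}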
private def estimateColumnIntervalOfAggregate(
aggregate: SingleRel,
......@@ -451,8 +456,16 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
case agg: StreamExecLocalGroupAggregate => agg.grouping
case agg: StreamExecGlobalGroupAggregate => agg.grouping
case agg: StreamExecIncrementalGroupAggregate => agg.partialAggGrouping
case agg: StreamExecGroupWindowAggregate => agg.getGrouping
case agg: BatchExecGroupAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
case agg: Aggregate => AggregateUtil.checkAndGetFullGroupSet(agg)
case agg: BatchExecLocalSortWindowAggregate =>
// grouping + assignTs + auxGrouping
agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
case agg: BatchExecLocalHashWindowAggregate =>
// grouping + assignTs + auxGrouping
agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
case agg: BatchExecWindowAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
}
if (index < groupSet.length) {
......@@ -513,6 +526,8 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
case agg: StreamExecIncrementalGroupAggregate
if agg.partialAggInfoList.getActualAggregateCalls.length > aggCallIndex =>
agg.partialAggInfoList.getActualAggregateCalls(aggCallIndex)
case agg: StreamExecGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
agg.aggCalls(aggCallIndex)
case agg: BatchExecLocalHashAggregate =>
getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
case agg: BatchExecHashAggregate if agg.isMerge =>
......@@ -542,6 +557,8 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
} else {
null
}
case agg: BatchExecWindowAggregateBase if agg.getAggCallList.length > aggCallIndex =>
agg.getAggCallList(aggCallIndex)
case _ => null
}
......@@ -580,12 +597,68 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets intervals of the given column in Join.
* Gets interval of the given column on calcite window.
*
* @param window Window RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column on window
*/
def getColumnInterval(
window: Window,
mq: RelMetadataQuery,
index: Int): ValueInterval = {
getColumnIntervalOfOverAgg(window, mq, index)
}
/**
* Gets interval of the given column on batch over aggregate.
*
* @param agg batch over aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column on batch over aggregate.
*/
def getColumnInterval(
agg: BatchExecOverAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = getColumnIntervalOfOverAgg(agg, mq, index)
/**
* Gets interval of the given column on stream over aggregate.
*
* @param agg stream over aggregate RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column on stream over aggregate.
*/
def getColumnInterval(
agg: StreamExecOverAggregate,
mq: RelMetadataQuery,
index: Int): ValueInterval = getColumnIntervalOfOverAgg(agg, mq, index)
private def getColumnIntervalOfOverAgg(
overAgg: SingleRel,
mq: RelMetadataQuery,
index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
val input = overAgg.getInput
val fieldsCountOfInput = input.getRowType.getFieldCount
if (index < fieldsCountOfInput) {
fmq.getColumnInterval(input, index)
} else {
// cannot estimate the column interval of aggregate function calls.
null
}
}
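Over aggregates keep all input fields and append the agg-call results, so this helper only forwards indices that fall inside the input row type. A sketch of that guard (plain Scala, hypothetical names):

object OverAggIntervalSketch {
  // Returns the input column to query, or None when the index points at an
  // appended agg-call result whose interval cannot be estimated.
  def intervalSource(inputFieldCount: Int, index: Int): Option[Int] =
    if (index < inputFieldCount) Some(index) else None

  def main(args: Array[String]): Unit = {
    println(intervalSource(inputFieldCount = 5, index = 3))  // Some(3)
    println(intervalSource(inputFieldCount = 5, index = 6))  // None
  }
}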
/**
* Gets interval of the given column on Join.
*
* @param join Join RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in Join
* @return interval of the given column on Join
*/
def getColumnInterval(join: Join, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
......@@ -612,12 +685,12 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets interval of the given column in Union.
* Gets interval of the given column on Union.
*
* @param union Union RelNode
* @param mq RelMetadataQuery instance
* @param index the index of the given column
* @return interval of the given column in Union
* @return interval of the given column on Union
*/
def getColumnInterval(union: Union, mq: RelMetadataQuery, index: Int): ValueInterval = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
......@@ -628,7 +701,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
/**
* Gets intervals of the given column of RelSubset.
* Gets interval of the given column on RelSubset.
*
* @param subset RelSubset to analyze
* @param mq RelMetadataQuery instance
......
......@@ -20,8 +20,9 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.JBoolean
import org.apache.flink.table.api.TableException
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.plan.nodes.FlinkRelNode
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.logical._
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.nodes.physical.stream._
......@@ -305,9 +306,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
// group by keys form a unique key
val groupKey = ImmutableBitSet.range(rel.getGroupCount)
columns.contains(groupKey)
areColumnsUniqueOnAggregate(rel.getGroupSet.toArray, mq, columns, ignoreNulls)
}
def areColumnsUnique(
......@@ -316,9 +315,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
if (rel.isFinal) {
// group key of agg output always starts from 0
val outputGroupKey = ImmutableBitSet.range(rel.getGrouping.length)
columns.contains(outputGroupKey)
areColumnsUniqueOnAggregate(rel.getGrouping, mq, columns, ignoreNulls)
} else {
null
}
......@@ -329,9 +326,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
// group key of agg output always starts from 0
val outputGroupKey = ImmutableBitSet.range(rel.grouping.length)
columns.contains(outputGroupKey)
areColumnsUniqueOnAggregate(rel.grouping, mq, columns, ignoreNulls)
}
def areColumnsUnique(
......@@ -339,7 +334,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
columns.contains(ImmutableBitSet.of(rel.grouping.toArray: _*))
areColumnsUniqueOnAggregate(rel.grouping, mq, columns, ignoreNulls)
}
def areColumnsUnique(
......@@ -348,32 +343,105 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = null
// TODO supports window aggregate
private def areColumnsUniqueOnAggregate(
grouping: Array[Int],
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
// group key of agg output always starts from 0
val outputGroupKey = ImmutableBitSet.of(grouping.indices: _*)
columns.contains(outputGroupKey)
}
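Since every group aggregate places its group keys first in the output, the shared helper only needs the grouping length. A standalone sketch of the check, using plain Scala sets and hypothetical names:

object AggUniquenessSketch {
  // The agg output starts with its group keys at indices 0..groupingLen-1;
  // any queried column set containing all of them is unique.
  def areColumnsUnique(groupingLen: Int, columns: Set[Int]): Boolean =
    (0 until groupingLen).toSet.subsetOf(columns)

  def main(args: Array[String]): Unit = {
    println(areColumnsUnique(2, Set(0, 1, 3)))  // true
    println(areColumnsUnique(2, Set(1, 3)))     // false
  }
}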
def areColumnsUnique(
rel: WindowAggregate,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
areColumnsUniqueOnWindowAggregate(
rel.getGroupSet.toArray,
rel.getNamedProperties,
rel.getRowType.getFieldCount,
mq,
columns,
ignoreNulls)
}
def areColumnsUnique(
rel: BatchExecWindowAggregateBase,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
if (rel.isFinal) {
areColumnsUniqueOnWindowAggregate(
rel.getGrouping,
rel.getNamedProperties,
rel.getRowType.getFieldCount,
mq,
columns,
ignoreNulls)
} else {
null
}
}
def areColumnsUnique(
rel: StreamExecGroupWindowAggregate,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
areColumnsUniqueOnWindowAggregate(
rel.getGrouping,
rel.getWindowProperties,
rel.getRowType.getFieldCount,
mq,
columns,
ignoreNulls)
}
private def areColumnsUniqueOnWindowAggregate(
grouping: Array[Int],
namedProperties: Seq[NamedWindowProperty],
outputFieldCount: Int,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
if (namedProperties.nonEmpty) {
val begin = outputFieldCount - namedProperties.size
val end = outputFieldCount - 1
val keys = ImmutableBitSet.of(grouping.indices: _*)
(begin to end).map {
i => keys.union(ImmutableBitSet.of(i))
}.exists(columns.contains)
} else {
false
}
}
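Group keys alone are not unique for a window aggregate, since the same keys recur in every window; they only become unique together with one of the trailing named window properties (window start/end). A standalone sketch of that bitset check, using plain Scala sets and hypothetical names:

object WindowAggUniquenessSketch {
  def areColumnsUnique(
      groupingLen: Int,
      namedPropCount: Int,
      outputFieldCount: Int,
      columns: Set[Int]): Boolean = {
    if (namedPropCount == 0) return false
    val keys = (0 until groupingLen).toSet
    val propIndices = (outputFieldCount - namedPropCount) until outputFieldCount
    // unique iff the queried columns cover the keys plus at least one property
    propIndices.exists(i => (keys + i).subsetOf(columns))
  }

  def main(args: Array[String]): Unit = {
    // 2 group keys; window_start/window_end occupy output indices 4 and 5
    println(areColumnsUnique(2, 2, 6, Set(0, 1, 4)))  // true
    println(areColumnsUnique(2, 2, 6, Set(0, 1)))     // false
  }
}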
def areColumnsUnique(
rel: Window,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverWindow(rel, mq, columns, ignoreNulls)
ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverAgg(rel, mq, columns, ignoreNulls)
def areColumnsUnique(
rel: BatchExecOverAggregate,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverWindow(rel, mq, columns, ignoreNulls)
ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverAgg(rel, mq, columns, ignoreNulls)
def areColumnsUnique(
rel: StreamExecOverAggregate,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverWindow(rel, mq, columns, ignoreNulls)
ignoreNulls: Boolean): JBoolean = areColumnsUniqueOfOverAgg(rel, mq, columns, ignoreNulls)
private def areColumnsUniqueOfOverWindow(
overWindow: SingleRel,
private def areColumnsUniqueOfOverAgg(
overAgg: SingleRel,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
val input = overWindow.getInput
val input = overAgg.getInput
val inputFieldLength = input.getRowType.getFieldCount
val columnsBelongsToInput = ImmutableBitSet.of(columns.filter(_ < inputFieldLength).toList)
val isSubColumnsUnique = mq.areColumnsUnique(
......
......@@ -19,7 +19,7 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.api.{PlannerConfigOptions, TableException}
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.schema.FlinkRelOptTable
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, FlinkRelOptUtil, FlinkRexUtil, RankUtil}
......@@ -399,26 +399,86 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
case rel: BatchExecGroupAggregateBase =>
FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
case rel: BatchExecWindowAggregateBase =>
FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
}
// TODO supports window aggregate
def getDistinctRowCount(
rel: WindowAggregate,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet,
predicate: RexNode): JDouble = {
val newPredicate = FlinkRelMdUtil.makeNamePropertiesSelectivityRexNode(rel, predicate)
if (newPredicate == null || newPredicate.isAlwaysTrue) {
if (groupKey.isEmpty) {
return 1D
}
}
val fieldCnt = rel.getRowType.getFieldCount
val namedPropertiesCnt = rel.getNamedProperties.size
val namedWindowStartIndex = fieldCnt - namedPropertiesCnt
val groupKeyFromNamedWindow = groupKey.toList.exists(_ >= namedWindowStartIndex)
if (groupKeyFromNamedWindow) {
// cannot estimate DistinctRowCount result when some group keys are from named windows
null
} else {
getDistinctRowCountOfAggregate(rel, mq, groupKey, newPredicate)
}
}
def getDistinctRowCount(
rel: BatchExecWindowAggregateBase,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet,
predicate: RexNode): JDouble = {
if (predicate == null || predicate.isAlwaysTrue) {
if (groupKey.isEmpty) {
return 1D
}
}
val newPredicate = if (rel.isFinal) {
val namedWindowStartIndex = rel.getRowType.getFieldCount - rel.getNamedProperties.size
val groupKeyFromNamedWindow = groupKey.toList.exists(_ >= namedWindowStartIndex)
if (groupKeyFromNamedWindow) {
// cannot estimate DistinctRowCount result when some group keys are from named windows
return null
}
val newPredicate = FlinkRelMdUtil.makeNamePropertiesSelectivityRexNode(rel, predicate)
if (rel.isMerge) {
// remap the group-key bits so they correspond to the local window aggregate
val localWinAggGroupKey = FlinkRelMdUtil.setChildKeysOfWinAgg(groupKey, rel)
val childPredicate = FlinkRelMdUtil.setChildPredicateOfWinAgg(newPredicate, rel)
return mq.getDistinctRowCount(rel.getInput, localWinAggGroupKey, childPredicate)
} else {
newPredicate
}
} else {
// local window aggregate
val assignTsFieldIndex = rel.getGrouping.length
if (groupKey.toList.contains(assignTsFieldIndex)) {
// groupKey contains `assignTs` fields
return null
}
predicate
}
getDistinctRowCountOfAggregate(rel, mq, groupKey, newPredicate)
}
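Both branches above refuse to estimate when a group key lands on a named window property (final aggregate) or on the assignTs field (local aggregate). A plain-Scala sketch of those two guards, with hypothetical names:

object WinAggGroupKeyGuards {
  // Final window aggregate: named properties occupy the trailing output fields.
  def touchesNamedWindow(fieldCount: Int, namedPropCount: Int, groupKey: Seq[Int]): Boolean =
    groupKey.exists(_ >= fieldCount - namedPropCount)

  // Local window aggregate: assignTs is emitted right after the grouping keys.
  def touchesAssignTs(groupingLen: Int, groupKey: Seq[Int]): Boolean =
    groupKey.contains(groupingLen)

  def main(args: Array[String]): Unit = {
    println(touchesNamedWindow(fieldCount = 6, namedPropCount = 2, groupKey = Seq(0, 4)))  // true
    println(touchesAssignTs(groupingLen = 2, groupKey = Seq(0, 2)))                        // true
  }
}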
def getDistinctRowCount(
rel: Window,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet,
predicate: RexNode): JDouble =
getDistinctRowCountOfOverWindow(rel, mq, groupKey, predicate)
predicate: RexNode): JDouble = getDistinctRowCountOfOverAgg(rel, mq, groupKey, predicate)
def getDistinctRowCount(
rel: BatchExecOverAggregate,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet,
predicate: RexNode): JDouble =
getDistinctRowCountOfOverWindow(rel, mq, groupKey, predicate)
predicate: RexNode): JDouble = getDistinctRowCountOfOverAgg(rel, mq, groupKey, predicate)
private def getDistinctRowCountOfOverWindow(
overWindow: SingleRel,
private def getDistinctRowCountOfOverAgg(
overAgg: SingleRel,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet,
predicate: RexNode): JDouble = {
......@@ -427,10 +487,10 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
return 1D
}
}
val input = overWindow.getInput
val input = overAgg.getInput
val fieldsCountOfInput = input.getRowType.getFieldCount
val groupKeyContainsAggCall = groupKey.toList.exists(_ >= fieldsCountOfInput)
// cannot estimate ndv of aggCall result of OverWindowAgg
// cannot estimate ndv of aggCall result of OverAgg
if (groupKeyContainsAggCall) {
null
} else {
......@@ -441,7 +501,7 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
predicate,
pushable,
notPushable)
val rexBuilder = overWindow.getCluster.getRexBuilder
val rexBuilder = overAgg.getCluster.getRexBuilder
val childPreds = RexUtil.composeConjunction(rexBuilder, pushable, true)
val distinctRowCount = mq.getDistinctRowCount(input, groupKey, childPreds)
if (distinctRowCount == null) {
......@@ -450,7 +510,7 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata
distinctRowCount
} else {
val preds = RexUtil.composeConjunction(rexBuilder, notPushable, true)
val rowCount = mq.getRowCount(overWindow)
val rowCount = mq.getRowCount(overAgg)
FlinkRelMdUtil.adaptNdvBasedOnSelectivity(rowCount, distinctRowCount,
RelMdUtil.guessSelectivity(preds))
}
......
......@@ -19,7 +19,7 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.plan.metadata.FlinkMetadata.FilteredColumnInterval
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecGroupAggregateBase
import org.apache.flink.table.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecGroupAggregate, StreamExecLocalGroupAggregate}
import org.apache.flink.table.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecGroupAggregate, StreamExecGroupWindowAggregate, StreamExecLocalGroupAggregate}
import org.apache.flink.table.plan.stats.ValueInterval
import org.apache.flink.table.plan.util.ColumnIntervalUtil
import org.apache.flink.util.Preconditions.checkArgument
......@@ -198,7 +198,13 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredC
estimateFilteredColumnIntervalOfAggregate(aggregate, mq, columnIndex, filterArg)
}
// TODO support window aggregate
def getColumnInterval(
aggregate: StreamExecGroupWindowAggregate,
mq: RelMetadataQuery,
columnIndex: Int,
filterArg: Int): ValueInterval = {
estimateFilteredColumnIntervalOfAggregate(aggregate, mq, columnIndex, filterArg)
}
def estimateFilteredColumnIntervalOfAggregate(
rel: RelNode,
......
......@@ -272,7 +272,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
}
def getRelModifiedMonotonicity(
rel: FlinkLogicalOverWindow,
rel: FlinkLogicalOverAggregate,
mq: RelMetadataQuery): RelModifiedMonotonicity = constants(rel.getRowType.getFieldCount)
def getRelModifiedMonotonicity(
......
......@@ -19,7 +19,7 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, RankUtil}
import org.apache.flink.table.{JArrayList, JDouble}
......@@ -272,26 +272,67 @@ class FlinkRelMdPopulationSize private extends MetadataHandler[BuiltInMetadata.P
NumberUtil.min(popSizeOfColsInGroupKeys * popSizeOfColsInAggCalls, inputRowCnt)
}
// TODO supports window aggregate
def getPopulationSize(
rel: WindowAggregate,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet): JDouble = {
val fieldCnt = rel.getRowType.getFieldCount
val namedPropertiesCnt = rel.getNamedProperties.size
val namedWindowStartIndex = fieldCnt - namedPropertiesCnt
val groupKeyFromNamedWindow = groupKey.toList.exists(_ >= namedWindowStartIndex)
if (groupKeyFromNamedWindow) {
// cannot estimate PopulationSize result when some group keys are from named windows
null
} else {
// regular aggregate
getPopulationSize(rel.asInstanceOf[Aggregate], mq, groupKey)
}
}
def getPopulationSize(
rel: BatchExecWindowAggregateBase,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet): JDouble = {
if (rel.isFinal) {
val namedWindowStartIndex = rel.getRowType.getFieldCount - rel.getNamedProperties.size
val groupKeyFromNamedWindow = groupKey.toList.exists(_ >= namedWindowStartIndex)
if (groupKeyFromNamedWindow) {
return null
}
if (rel.isMerge) {
// remap the group-key bits so they correspond to the local window aggregate
val localWinAggGroupKey = FlinkRelMdUtil.setChildKeysOfWinAgg(groupKey, rel)
return mq.getPopulationSize(rel.getInput, localWinAggGroupKey)
}
} else {
// local window aggregate
val assignTsFieldIndex = rel.getGrouping.length
if (groupKey.toList.contains(assignTsFieldIndex)) {
// groupKey contains `assignTs` fields
return null
}
}
getPopulationSizeOfAggregate(rel, mq, groupKey)
}
def getPopulationSize(
window: Window,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet): JDouble = getPopulationSizeOfOverWindow(window, mq, groupKey)
groupKey: ImmutableBitSet): JDouble = getPopulationSizeOfOverAgg(window, mq, groupKey)
def getPopulationSize(
rel: BatchExecOverAggregate,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet): JDouble = getPopulationSizeOfOverWindow(rel, mq, groupKey)
groupKey: ImmutableBitSet): JDouble = getPopulationSizeOfOverAgg(rel, mq, groupKey)
private def getPopulationSizeOfOverWindow(
overWindow: SingleRel,
private def getPopulationSizeOfOverAgg(
overAgg: SingleRel,
mq: RelMetadataQuery,
groupKey: ImmutableBitSet): JDouble = {
val input = overWindow.getInput
val input = overAgg.getInput
val fieldsCountOfInput = input.getRowType.getFieldCount
val groupKeyContainsAggCall = groupKey.toList.exists(_ >= fieldsCountOfInput)
// cannot estimate population size of aggCall result of OverWindowAgg
// cannot estimate population size of aggCall result of OverAgg
if (groupKeyContainsAggCall) {
null
} else {
......
......@@ -20,10 +20,12 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.JDouble
import org.apache.flink.table.calcite.FlinkContext
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
import org.apache.flink.table.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow}
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.exec.NodeResourceConfig
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.stats.ValueInterval
import org.apache.flink.table.plan.util.AggregateUtil.{extractTimeIntervalValue, isTimeIntervalType}
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, SortUtil}
import org.apache.calcite.adapter.enumerable.EnumerableLimit
......@@ -142,6 +144,8 @@ class FlinkRelMdRowCount private extends MetadataHandler[BuiltInMetadata.RowCoun
val (grouping, isFinal, isMerge) = rel match {
case agg: BatchExecGroupAggregateBase =>
(ImmutableBitSet.of(agg.getGrouping: _*), agg.isFinal, agg.isMerge)
case windowAgg: BatchExecWindowAggregateBase =>
(ImmutableBitSet.of(windowAgg.getGrouping: _*), windowAgg.isFinal, windowAgg.isMerge)
case _ => throw new IllegalArgumentException(s"Unknown aggregate type ${rel.getRelTypeName}!")
}
val ndvOfGroupKeysOnGlobalAgg: JDouble = if (grouping.isEmpty) {
......@@ -185,16 +189,56 @@ class FlinkRelMdRowCount private extends MetadataHandler[BuiltInMetadata.RowCoun
}
}
// TODO supports window aggregate
def getRowCount(rel: WindowAggregate, mq: RelMetadataQuery): JDouble = {
val (ndvOfGroupKeys, inputRowCount) = getRowCountOfAgg(rel, rel.getGroupSet, 1, mq)
estimateRowCountOfWindowAgg(ndvOfGroupKeys, inputRowCount, rel.getWindow)
}
def getRowCount(rel: BatchExecWindowAggregateBase, mq: RelMetadataQuery): JDouble = {
val ndvOfGroupKeys = getRowCountOfBatchExecAgg(rel, mq)
val inputRowCount = mq.getRowCount(rel.getInput)
estimateRowCountOfWindowAgg(ndvOfGroupKeys, inputRowCount, rel.getWindow)
}
private def estimateRowCountOfWindowAgg(
ndv: JDouble,
inputRowCount: JDouble,
window: LogicalWindow): JDouble = {
if (ndv == null) {
null
} else {
// Simply assume the expand factor of TumblingWindow/SessionWindow/SlideWindowWithoutOverlap
// is 2, and that of SlideWindowWithOverlap is 4.
// The expand factor is introduced here to distinguish the output rowCount of a normal agg
// from that of the various window aggregates.
val expandFactorOfTumblingWindow = 2D
val expandFactorOfNoOverLapSlidingWindow = 2D
val expandFactorOfOverLapSlidingWindow = 4D
val expandFactorOfSessionWindow = 2D
window match {
case TumblingGroupWindow(_, _, size) if isTimeIntervalType(size.getType) =>
Math.min(expandFactorOfTumblingWindow * ndv, inputRowCount)
case SlidingGroupWindow(_, _, size, slide) if isTimeIntervalType(size.getType) =>
val sizeValue = extractTimeIntervalValue(size)
val slideValue = extractTimeIntervalValue(slide)
if (sizeValue > slideValue) {
// only a sliding window with overlap may generate more records than its input
expandFactorOfOverLapSlidingWindow * ndv
} else {
Math.min(expandFactorOfNoOverLapSlidingWindow * ndv, inputRowCount)
}
case _ => Math.min(expandFactorOfSessionWindow * ndv, inputRowCount)
}
}
}
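A worked illustration of the expand-factor heuristic above, reduced to a standalone plain-Scala function (hypothetical names; millisecond sizes stand in for the extracted interval values):

object WindowAggRowCountSketch {
  // Tumbling/session/non-overlapping sliding windows expand the group-key NDV
  // by 2, capped at the input row count; overlapping sliding windows expand it
  // by 4, uncapped.
  def estimate(ndv: Double, inputRowCount: Double, sizeMs: Long, slideMs: Option[Long]): Double =
    slideMs match {
      case Some(slide) if sizeMs > slide => 4d * ndv               // overlapping sliding window
      case _                             => math.min(2d * ndv, inputRowCount)
    }

  def main(args: Array[String]): Unit = {
    println(estimate(ndv = 100d, inputRowCount = 1000d, sizeMs = 60000, slideMs = Some(10000))) // 400.0
    println(estimate(ndv = 100d, inputRowCount = 150d, sizeMs = 60000, slideMs = None))         // 150.0
  }
}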
def getRowCount(rel: Window, mq: RelMetadataQuery): JDouble =
getRowCountOfOverWindow(rel, mq)
def getRowCount(rel: Window, mq: RelMetadataQuery): JDouble = getRowCountOfOverAgg(rel, mq)
def getRowCount(rel: BatchExecOverAggregate, mq: RelMetadataQuery): JDouble =
getRowCountOfOverWindow(rel, mq)
getRowCountOfOverAgg(rel, mq)
private def getRowCountOfOverWindow(overWindow: SingleRel, mq: RelMetadataQuery): JDouble =
mq.getRowCount(overWindow.getInput)
private def getRowCountOfOverAgg(overAgg: SingleRel, mq: RelMetadataQuery): JDouble =
mq.getRowCount(overAgg.getInput)
def getRowCount(join: Join, mq: RelMetadataQuery): JDouble = {
join.getJoinType match {
......
......@@ -18,7 +18,7 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.util.FlinkRelMdUtil
import org.apache.flink.table.{JArrayList, JDouble}
......@@ -101,6 +101,26 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
mq: RelMetadataQuery,
predicate: RexNode): JDouble = getSelectivityOfAgg(rel, mq, predicate)
def getSelectivity(
rel: WindowAggregate,
mq: RelMetadataQuery,
predicate: RexNode): JDouble = {
val newPredicate = FlinkRelMdUtil.makeNamePropertiesSelectivityRexNode(rel, predicate)
getSelectivityOfAgg(rel, mq, newPredicate)
}
def getSelectivity(
rel: BatchExecWindowAggregateBase,
mq: RelMetadataQuery,
predicate: RexNode): JDouble = {
val newPredicate = if (rel.isFinal) {
FlinkRelMdUtil.makeNamePropertiesSelectivityRexNode(rel, predicate)
} else {
predicate
}
getSelectivityOfAgg(rel, mq, newPredicate)
}
private def getSelectivityOfAgg(
agg: SingleRel,
mq: RelMetadataQuery,
......@@ -111,10 +131,17 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
val hasLocalAgg = agg match {
case _: Aggregate => false
case rel: BatchExecGroupAggregateBase => rel.isFinal && rel.isMerge
case rel: BatchExecWindowAggregateBase => rel.isFinal && rel.isMerge
case _ => throw new IllegalArgumentException(s"Cannot handle ${agg.getRelTypeName}!")
}
if (hasLocalAgg) {
return mq.getSelectivity(agg.getInput, predicate)
val childPredicate = agg match {
case rel: BatchExecWindowAggregateBase =>
// adjust the predicate so it corresponds to the local window aggregate
FlinkRelMdUtil.setChildPredicateOfWinAgg(predicate, rel)
case _ => predicate
}
return mq.getSelectivity(agg.getInput, childPredicate)
}
val (childPred, restPred) = agg match {
......@@ -122,6 +149,8 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
case rel: BatchExecGroupAggregateBase =>
FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
case rel: BatchExecWindowAggregateBase =>
FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate)
case _ => throw new IllegalArgumentException(s"Cannot handle ${agg.getRelTypeName}!")
}
val childSelectivity = mq.getSelectivity(agg.getInput(), childPred.orNull)
......@@ -139,19 +168,17 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele
}
}
// TODO supports window aggregate
def getSelectivity(
overWindow: Window,
mq: RelMetadataQuery,
predicate: RexNode): JDouble = getSelectivityOfOverWindowAgg(overWindow, mq, predicate)
predicate: RexNode): JDouble = getSelectivityOfOverAgg(overWindow, mq, predicate)
def getSelectivity(
rel: BatchExecOverAggregate,
mq: RelMetadataQuery,
predicate: RexNode): JDouble = getSelectivityOfOverWindowAgg(rel, mq, predicate)
predicate: RexNode): JDouble = getSelectivityOfOverAgg(rel, mq, predicate)
private def getSelectivityOfOverWindowAgg(
private def getSelectivityOfOverAgg(
over: SingleRel,
mq: RelMetadataQuery,
predicate: RexNode): JDouble = {
......
......@@ -231,16 +231,16 @@ class FlinkRelMdSize private extends MetadataHandler[BuiltInMetadata.Size] {
}
def averageColumnSizes(overWindow: Window, mq: RelMetadataQuery): JList[JDouble] =
averageColumnSizesOfOverWindow(overWindow, mq)
averageColumnSizesOfOverAgg(overWindow, mq)
def averageColumnSizes(rel: BatchExecOverAggregate, mq: RelMetadataQuery): JList[JDouble] =
averageColumnSizesOfOverWindow(rel, mq)
averageColumnSizesOfOverAgg(rel, mq)
private def averageColumnSizesOfOverWindow(
overWindow: SingleRel,
private def averageColumnSizesOfOverAgg(
overAgg: SingleRel,
mq: RelMetadataQuery): JList[JDouble] = {
val inputFieldCount = overWindow.getInput.getRowType.getFieldCount
getColumnSizesFromInputOrType(overWindow, mq, (0 until inputFieldCount).zipWithIndex.toMap)
val inputFieldCount = overAgg.getInput.getRowType.getFieldCount
getColumnSizesFromInputOrType(overAgg, mq, (0 until inputFieldCount).zipWithIndex.toMap)
}
def averageColumnSizes(rel: Join, mq: RelMetadataQuery): JList[JDouble] = {
......
......@@ -18,15 +18,16 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.plan.metadata.FlinkMetadata.UniqueGroups
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, RankUtil}
import org.apache.flink.table.plan.util.{AggregateUtil, FlinkRelMdUtil, RankUtil}
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata._
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.util.{Bug, ImmutableBitSet, Util}
......@@ -252,7 +253,62 @@ class FlinkRelMdUniqueGroups private extends MetadataHandler[UniqueGroups] {
}
}
// TODO support window aggregate
def getUniqueGroups(
agg: WindowAggregate,
mq: RelMetadataQuery,
columns: ImmutableBitSet): ImmutableBitSet = {
val grouping = agg.getGroupSet.map(_.toInt).toArray
val namedProperties = agg.getNamedProperties
val (auxGroupSet, _) = AggregateUtil.checkAndSplitAggCalls(agg)
getUniqueGroupsOfWindowAgg(agg, grouping, auxGroupSet, namedProperties, mq, columns)
}
def getUniqueGroups(
agg: BatchExecWindowAggregateBase,
mq: RelMetadataQuery,
columns: ImmutableBitSet): ImmutableBitSet = {
val grouping = agg.getGrouping
val namedProperties = agg.getNamedProperties
getUniqueGroupsOfWindowAgg(agg, grouping, agg.getAuxGrouping, namedProperties, mq, columns)
}
private def getUniqueGroupsOfWindowAgg(
windowAgg: SingleRel,
grouping: Array[Int],
auxGrouping: Array[Int],
namedProperties: Seq[NamedWindowProperty],
mq: RelMetadataQuery,
columns: ImmutableBitSet): ImmutableBitSet = {
val fieldCount = windowAgg.getRowType.getFieldCount
val columnList = columns.toList
val groupingInToOutMap = new mutable.HashMap[Integer, Integer]()
columnList.foreach { column =>
require(column < fieldCount)
if (column < grouping.length) {
groupingInToOutMap.put(grouping(column), column)
}
}
if (groupingInToOutMap.isEmpty) {
columns
} else {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
val inputColumns = ImmutableBitSet.of(groupingInToOutMap.keys.toList)
val inputUniqueGroups = fmq.getUniqueGroups(windowAgg.getInput, inputColumns)
val uniqueGroupsFromGrouping = inputUniqueGroups.asList.map { i =>
groupingInToOutMap.getOrElse(i, throw new IllegalArgumentException(s"Illegal index: $i"))
}
val fullGroupingOutputIndices =
grouping.indices ++ auxGrouping.indices.map(_ + grouping.length)
if (columns.equals(ImmutableBitSet.of(fullGroupingOutputIndices: _*))) {
return ImmutableBitSet.of(uniqueGroupsFromGrouping)
}
val groupingOutCols = groupingInToOutMap.values
// TODO drop some nonGroupingCols based on FlinkRelMdColumnUniqueness#areColumnsUnique(window)
val nonGroupingCols = columnList.filterNot(groupingOutCols.contains)
ImmutableBitSet.of(uniqueGroupsFromGrouping).union(ImmutableBitSet.of(nonGroupingCols))
}
}
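The helper above walks the queried output columns back through the grouping to input columns, asks the input for its unique groups, then translates the answer forward again. A plain-Scala sketch of just the back-mapping step, with hypothetical names:

object UniqueGroupsSketch {
  // Output column c (< grouping.length) corresponds to input column grouping(c);
  // the map records inputColumn -> outputColumn for later translation back.
  def mapThroughGrouping(grouping: Array[Int], outputColumns: Seq[Int]): Map[Int, Int] =
    outputColumns.collect { case c if c < grouping.length => grouping(c) -> c }.toMap

  def main(args: Array[String]): Unit = {
    // grouping: output 0 <- input 3, output 1 <- input 5; output 4 is an agg call
    println(mapThroughGrouping(Array(3, 5), Seq(0, 1, 4)))  // Map(3 -> 0, 5 -> 1)
  }
}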
def getUniqueGroups(
over: Window,
......
......@@ -18,7 +18,8 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.logical._
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.nodes.physical.stream._
......@@ -32,7 +33,7 @@ import com.google.common.collect.ImmutableSet
import org.apache.calcite.plan.RelOptTable
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{JoinRelType, _}
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata._
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode}
......@@ -262,8 +263,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
rel: Aggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
// group by keys form a unique key
ImmutableSet.of(ImmutableBitSet.range(rel.getGroupCount))
getUniqueKeysOnAggregate(rel.getGroupSet.toArray, mq, ignoreNulls)
}
def getUniqueKeys(
......@@ -271,8 +271,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
if (rel.isFinal) {
// group by keys form a unique key
ImmutableSet.of(ImmutableBitSet.of(rel.getGrouping.indices: _*))
getUniqueKeysOnAggregate(rel.getGrouping, mq, ignoreNulls)
} else {
null
}
......@@ -282,8 +281,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
rel: StreamExecGroupAggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
// group by keys form a unique key
toImmutableSet(rel.grouping.indices.toArray)
getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls)
}
def getUniqueKeys(
......@@ -291,44 +289,98 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = null
def getUniqueKeys(
rel: StreamExecGlobalGroupAggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
ImmutableSet.of(ImmutableBitSet.of(rel.grouping.indices.toArray: _*))
getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls)
}
def getUniqueKeysOnAggregate(
grouping: Array[Int],
mq: RelMetadataQuery,
ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
// group by keys form a unique key
ImmutableSet.of(ImmutableBitSet.of(grouping.indices: _*))
}
def getUniqueKeys(
rel: StreamExecWindowJoin,
rel: WindowAggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
val joinInfo = JoinInfo.of(rel.getLeft, rel.getRight, rel.joinCondition)
getJoinUniqueKeys(joinInfo, rel.joinType, rel.getLeft, rel.getRight, mq, ignoreNulls)
ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
getUniqueKeysOnWindowAgg(
rel.getRowType.getFieldCount,
rel.getNamedProperties,
rel.getGroupSet.toArray,
mq,
ignoreNulls)
}
def getUniqueKeys(
rel: BatchExecWindowAggregateBase,
mq: RelMetadataQuery,
ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
if (rel.isFinal) {
getUniqueKeysOnWindowAgg(
rel.getRowType.getFieldCount,
rel.getNamedProperties,
rel.getGrouping,
mq,
ignoreNulls)
} else {
null
}
}
def getUniqueKeys(
rel: StreamExecGroupWindowAggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
getUniqueKeysOnWindowAgg(
rel.getRowType.getFieldCount, rel.getWindowProperties, rel.getGrouping, mq, ignoreNulls)
}
private def getUniqueKeysOnWindowAgg(
fieldCount: Int,
namedProperties: Seq[NamedWindowProperty],
grouping: Array[Int],
mq: RelMetadataQuery,
ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
if (namedProperties.nonEmpty) {
val begin = fieldCount - namedProperties.size
val end = fieldCount - 1
// namedProperties' indices are at the end of the output record
val keys = ImmutableBitSet.of(grouping.indices: _*)
(begin to end).map {
i => keys.union(ImmutableBitSet.of(i))
}.toSet[ImmutableBitSet]
} else {
null
}
}
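Mirroring the uniqueness handler, each unique key of a window aggregate is the group keys plus exactly one trailing named window property. A standalone sketch (plain Scala, hypothetical names; the real handler returns null when there are no named properties):

object WindowAggUniqueKeysSketch {
  def uniqueKeys(fieldCount: Int, namedPropCount: Int, groupingLen: Int): Set[Set[Int]] = {
    val keys = (0 until groupingLen).toSet
    ((fieldCount - namedPropCount) until fieldCount).map(keys + _).toSet
  }

  def main(args: Array[String]): Unit = {
    println(uniqueKeys(fieldCount = 6, namedPropCount = 2, groupingLen = 2))
    // Set(Set(0, 1, 4), Set(0, 1, 5))
  }
}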
def getUniqueKeys(
rel: Window,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
getUniqueKeysOfOverWindow(rel, mq, ignoreNulls)
getUniqueKeysOfOverAgg(rel, mq, ignoreNulls)
}
def getUniqueKeys(
rel: BatchExecOverAggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
getUniqueKeysOfOverWindow(rel, mq, ignoreNulls)
getUniqueKeysOfOverAgg(rel, mq, ignoreNulls)
}
def getUniqueKeys(
rel: StreamExecOverAggregate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
getUniqueKeysOfOverWindow(rel, mq, ignoreNulls)
getUniqueKeysOfOverAgg(rel, mq, ignoreNulls)
}
private def getUniqueKeysOfOverWindow(
private def getUniqueKeysOfOverAgg(
window: SingleRel,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
......@@ -350,6 +402,14 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu
}
}
def getUniqueKeys(
rel: StreamExecWindowJoin,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
val joinInfo = JoinInfo.of(rel.getLeft, rel.getRight, rel.joinCondition)
getJoinUniqueKeys(joinInfo, rel.joinType, rel.getLeft, rel.getRight, mq, ignoreNulls)
}
private def getJoinUniqueKeys(
joinInfo: JoinInfo,
joinRelType: JoinRelType,
......
......@@ -19,16 +19,15 @@
package org.apache.flink.table.plan.nodes.logical
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.plan.metadata.FlinkRelMetadataQuery
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.calcite.plan._
import org.apache.calcite.rel.{RelCollation, RelCollationTraitDef, RelNode}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.core.Window
import org.apache.calcite.rel.logical.LogicalWindow
import org.apache.calcite.rel.metadata.RelMdCollation
import org.apache.calcite.rel.{RelCollation, RelCollationTraitDef, RelNode}
import org.apache.calcite.rex.RexLiteral
import org.apache.calcite.sql.SqlRankFunction
......@@ -42,7 +41,7 @@ import scala.collection.JavaConversions._
* Sub-class of [[Window]] that is a relational expression
* which represents a set of over window aggregates in Flink.
*/
class FlinkLogicalOverWindow(
class FlinkLogicalOverAggregate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
input: RelNode,
......@@ -53,7 +52,7 @@ class FlinkLogicalOverWindow(
with FlinkLogicalRel {
override def copy(traitSet: RelTraitSet, inputs: JList[RelNode]): RelNode = {
new FlinkLogicalOverWindow(
new FlinkLogicalOverAggregate(
cluster,
traitSet,
inputs.get(0),
......@@ -64,12 +63,12 @@ class FlinkLogicalOverWindow(
}
class FlinkLogicalOverWindowConverter
class FlinkLogicalOverAggregateConverter
extends ConverterRule(
classOf[LogicalWindow],
Convention.NONE,
FlinkConventions.LOGICAL,
"FlinkLogicalOverWindowConverter") {
"FlinkLogicalOverAggregateConverter") {
override def convert(rel: RelNode): RelNode = {
val window = rel.asInstanceOf[LogicalWindow]
......@@ -92,7 +91,7 @@ class FlinkLogicalOverWindowConverter
}
}
new FlinkLogicalOverWindow(
new FlinkLogicalOverAggregate(
rel.getCluster,
traitSet,
newInput,
......@@ -102,6 +101,6 @@ class FlinkLogicalOverWindowConverter
}
}
object FlinkLogicalOverWindow {
val CONVERTER = new FlinkLogicalOverWindowConverter
object FlinkLogicalOverAggregate {
val CONVERTER = new FlinkLogicalOverAggregateConverter
}
......@@ -18,7 +18,6 @@
package org.apache.flink.table.plan.nodes.logical
import org.apache.flink.table.plan.metadata.FlinkRelMetadataQuery
import org.apache.flink.table.plan.nodes.FlinkConventions
import com.google.common.collect.ImmutableList
......
......@@ -42,7 +42,7 @@ class BatchExecLocalHashWindowAggregate(
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
window: LogicalWindow,
inputTimeFieldIndex: Int,
val inputTimeFieldIndex: Int,
inputTimeIsDate: Boolean,
namedProperties: Seq[NamedWindowProperty],
enableAssignPane: Boolean = false)
......
......@@ -45,7 +45,7 @@ class BatchExecLocalSortWindowAggregate(
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
window: LogicalWindow,
inputTimeFieldIndex: Int,
val inputTimeFieldIndex: Int,
inputTimeIsDate: Boolean,
namedProperties: Seq[NamedWindowProperty],
enableAssignPane: Boolean = false)
......
......@@ -58,7 +58,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
/**
* Batch physical RelNode for sort-based over [[Window]].
* Batch physical RelNode for sort-based over [[Window]] aggregate.
*/
class BatchExecOverAggregate(
cluster: RelOptCluster,
......
......@@ -62,7 +62,7 @@ class StreamExecGroupWindowAggregate(
val aggCalls: Seq[AggregateCall],
val window: LogicalWindow,
namedProperties: Seq[NamedWindowProperty],
inputTimestampIndex: Int,
inputTimeFieldIndex: Int,
val emitStrategy: WindowEmitStrategy)
extends SingleRel(cluster, traitSet, inputRel)
with StreamPhysicalRel
......@@ -86,7 +86,7 @@ class StreamExecGroupWindowAggregate(
case _ => false
}
def getGroupings: Array[Int] = grouping
def getGrouping: Array[Int] = grouping
def getWindowProperties: Seq[NamedWindowProperty] = namedProperties
......@@ -103,7 +103,7 @@ class StreamExecGroupWindowAggregate(
aggCalls,
window,
namedProperties,
inputTimestampIndex,
inputTimeFieldIndex,
emitStrategy)
}
......@@ -177,14 +177,14 @@ class StreamExecGroupWindowAggregate(
namedProperties)
val timeIdx = if (isRowtimeIndicatorType(window.timeAttribute.getResultType)) {
if (inputTimestampIndex < 0) {
if (inputTimeFieldIndex < 0) {
throw new TableException(
"Group window aggregate must defined on a time attribute, " +
"but the time attribute can't be found.\n" +
"This should never happen. Please file an issue."
)
}
inputTimestampIndex
inputTimeFieldIndex
} else {
-1
}
......
......@@ -267,7 +267,7 @@ object FlinkBatchRuleSets {
*/
private val LOGICAL_CONVERTERS: RuleSet = RuleSets.ofList(
FlinkLogicalAggregate.BATCH_CONVERTER,
FlinkLogicalOverWindow.CONVERTER,
FlinkLogicalOverAggregate.CONVERTER,
FlinkLogicalCalc.CONVERTER,
FlinkLogicalCorrelate.CONVERTER,
FlinkLogicalJoin.CONVERTER,
......@@ -329,7 +329,7 @@ object FlinkBatchRuleSets {
BatchExecNestedLoopJoinRule.INSTANCE,
BatchExecSingleRowJoinRule.INSTANCE,
BatchExecCorrelateRule.INSTANCE,
BatchExecOverWindowAggRule.INSTANCE,
BatchExecOverAggregateRule.INSTANCE,
BatchExecWindowAggregateRule.INSTANCE,
BatchExecLookupJoinRule.SNAPSHOT_ON_TABLESCAN,
BatchExecLookupJoinRule.SNAPSHOT_ON_CALC_TABLESCAN,
......
......@@ -241,7 +241,7 @@ object FlinkStreamRuleSets {
private val LOGICAL_CONVERTERS: RuleSet = RuleSets.ofList(
// translate to flink logical rel nodes
FlinkLogicalAggregate.STREAM_CONVERTER,
FlinkLogicalOverWindow.CONVERTER,
FlinkLogicalOverAggregate.CONVERTER,
FlinkLogicalCalc.CONVERTER,
FlinkLogicalCorrelate.CONVERTER,
FlinkLogicalJoin.CONVERTER,
......
......@@ -19,7 +19,7 @@ package org.apache.flink.table.plan.rules.logical
import org.apache.flink.table.api.TableException
import org.apache.flink.table.calcite.FlinkContext
import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalOverWindow, FlinkLogicalRank}
import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalOverAggregate, FlinkLogicalRank}
import org.apache.flink.table.plan.util.RankUtil
import org.apache.flink.table.runtime.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankType}
......@@ -33,17 +33,17 @@ import org.apache.calcite.sql.{SqlKind, SqlRankFunction}
import scala.collection.JavaConversions._
/**
* Planner rule that matches a [[FlinkLogicalCalc]] on a [[FlinkLogicalOverWindow]],
* Planner rule that matches a [[FlinkLogicalCalc]] on a [[FlinkLogicalOverAggregate]],
* and converts them into a [[FlinkLogicalRank]].
*/
abstract class FlinkLogicalRankRuleBase
extends RelOptRule(
operand(classOf[FlinkLogicalCalc],
operand(classOf[FlinkLogicalOverWindow], any()))) {
operand(classOf[FlinkLogicalOverAggregate], any()))) {
override def onMatch(call: RelOptRuleCall): Unit = {
val calc: FlinkLogicalCalc = call.rel(0)
val window: FlinkLogicalOverWindow = call.rel(1)
val window: FlinkLogicalOverAggregate = call.rel(1)
val group = window.groups.get(0)
val rankFun = group.aggCalls.get(0).getOperator.asInstanceOf[SqlRankFunction]
......@@ -152,7 +152,7 @@ class FlinkLogicalRankRuleForRangeEnd extends FlinkLogicalRankRuleBase {
override def matches(call: RelOptRuleCall): Boolean = {
val calc: FlinkLogicalCalc = call.rel(0)
val window: FlinkLogicalOverWindow = call.rel(1)
val window: FlinkLogicalOverAggregate = call.rel(1)
if (window.groups.size > 1) {
// only accept one window
......@@ -175,7 +175,7 @@ class FlinkLogicalRankRuleForRangeEnd extends FlinkLogicalRankRuleBase {
val condition = calc.getProgram.getCondition
if (condition != null) {
val predicate = calc.getProgram.expandLocalRef(condition)
// the rank function is the last field of FlinkLogicalOverWindow
// the rank function is the last field of FlinkLogicalOverAggregate
val rankFieldIndex = window.getRowType.getFieldCount - 1
val config = calc.getCluster.getPlanner.getContext.asInstanceOf[FlinkContext].getTableConfig
val (rankRange, remainingPreds) = RankUtil.extractRankRange(
......@@ -217,7 +217,7 @@ class FlinkLogicalRankRuleForRangeEnd extends FlinkLogicalRankRuleBase {
class FlinkLogicalRankRuleForConstantRange extends FlinkLogicalRankRuleBase {
override def matches(call: RelOptRuleCall): Boolean = {
val calc: FlinkLogicalCalc = call.rel(0)
val window: FlinkLogicalOverWindow = call.rel(1)
val window: FlinkLogicalOverAggregate = call.rel(1)
if (window.groups.size > 1) {
// only accept one window
......@@ -240,7 +240,7 @@ class FlinkLogicalRankRuleForConstantRange extends FlinkLogicalRankRuleBase {
val condition = calc.getProgram.getCondition
if (condition != null) {
val predicate = calc.getProgram.expandLocalRef(condition)
// the rank function is the last field of FlinkLogicalOverWindow
// the rank function is the last field of FlinkLogicalOverAggregate
val rankFieldIndex = window.getRowType.getFieldCount - 1
val config = calc.getCluster.getPlanner.getContext.asInstanceOf[FlinkContext].getTableConfig
val (rankRange, remainingPreds) = RankUtil.extractRankRange(
......
......@@ -20,7 +20,7 @@ package org.apache.flink.table.plan.rules.physical.batch
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalOverWindow
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalOverAggregate
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecOverAggregate
import org.apache.flink.table.plan.util.{AggregateUtil, OverAggregateUtil, SortUtil}
......@@ -37,18 +37,18 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
/**
* Rule that converts [[FlinkLogicalOverWindow]] to one or more [[BatchExecOverAggregate]]s.
* Rule that converts [[FlinkLogicalOverAggregate]] to one or more [[BatchExecOverAggregate]]s.
* If there is more than one [[Group]], this rule will combine adjacent [[Group]]s with the
* same partition keys and order keys into one BatchExecOverAggregate.
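* For example (a hypothetical plan, mirroring the over-aggregate plan tests below):
* RANK() OVER (PARTITION BY b ORDER BY a) and ROW_NUMBER() OVER (PARTITION BY b ORDER BY a)
* have identical partition and order keys, so both groups are evaluated by a single
* BatchExecOverAggregate, while a window ordered by a different key gets its own node.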
*/
class BatchExecOverWindowAggRule
class BatchExecOverAggregateRule
extends RelOptRule(
operand(classOf[FlinkLogicalOverWindow],
operand(classOf[FlinkLogicalOverAggregate],
operand(classOf[RelNode], any)),
"BatchExecOverWindowAggRule") {
"BatchExecOverAggregateRule") {
override def onMatch(call: RelOptRuleCall): Unit = {
val logicWindow: FlinkLogicalOverWindow = call.rel(0)
val logicWindow: FlinkLogicalOverAggregate = call.rel(0)
var input: RelNode = call.rel(1)
var inputRowType = logicWindow.getInput.getRowType
val typeFactory = logicWindow.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
......@@ -135,7 +135,7 @@ class BatchExecOverWindowAggRule
/**
* Returns true if group1 satisfies group2 on keys and orderKeys, else false.
*/
def satisfies(group1: Group, group2: Group, logicWindow: FlinkLogicalOverWindow): Boolean = {
def satisfies(group1: Group, group2: Group, logicWindow: FlinkLogicalOverAggregate): Boolean = {
var isSatisfied = false
val keyComp = group1.keys.compareTo(group2.keys)
if (keyComp == 0) {
......@@ -176,6 +176,6 @@ class BatchExecOverWindowAggRule
}
}
object BatchExecOverWindowAggRule {
val INSTANCE: RelOptRule = new BatchExecOverWindowAggRule
object BatchExecOverAggregateRule {
val INSTANCE: RelOptRule = new BatchExecOverAggregateRule
}
......@@ -21,7 +21,7 @@ package org.apache.flink.table.plan.rules.physical.stream
import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalOverWindow
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalOverAggregate
import org.apache.flink.table.plan.nodes.physical.stream.StreamExecOverAggregate
import org.apache.calcite.plan.RelOptRule
......@@ -29,19 +29,19 @@ import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
/**
* Rule that converts [[FlinkLogicalOverWindow]] to [[StreamExecOverAggregate]].
* Rule that converts [[FlinkLogicalOverAggregate]] to [[StreamExecOverAggregate]].
* NOTES: StreamExecOverAggregate only supports one [[org.apache.calcite.rel.core.Window.Group]],
* otherwise an exception is thrown.
*/
class StreamExecOverAggregateRule
extends ConverterRule(
classOf[FlinkLogicalOverWindow],
classOf[FlinkLogicalOverAggregate],
FlinkConventions.LOGICAL,
FlinkConventions.STREAM_PHYSICAL,
"StreamExecOverAggregateRule") {
override def convert(rel: RelNode): RelNode = {
val logicWindow: FlinkLogicalOverWindow = rel.asInstanceOf[FlinkLogicalOverWindow]
val logicWindow: FlinkLogicalOverAggregate = rel.asInstanceOf[FlinkLogicalOverAggregate]
if (logicWindow.groups.size > 1) {
throw new TableException(
......
......@@ -17,18 +17,8 @@
*/
package org.apache.flink.table.plan.util
import java.lang.{Long => JLong}
import java.time.Duration
import java.util
import org.apache.calcite.rel.`type`._
import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rex.RexInputRef
import org.apache.calcite.sql.fun._
import org.apache.calcite.sql.validate.SqlMonotonicity
import org.apache.calcite.sql.{SqlKind, SqlRankFunction}
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation, Types}
import org.apache.flink.table.JLong
import org.apache.flink.table.`type`.InternalTypes._
import org.apache.flink.table.`type`.{DecimalType, InternalType, InternalTypes, TypeConverters}
import org.apache.flink.table.api.{TableConfig, TableConfigOptions, TableException}
......@@ -45,7 +35,18 @@ import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
import org.apache.flink.table.plan.`trait`.RelModifiedMonotonicity
import org.apache.flink.table.runtime.bundle.trigger.CountBundleTrigger
import org.apache.flink.table.typeutils._
import org.apache.flink.table.typeutils.{BaseRowTypeInfo, BinaryStringTypeInfo, DecimalTypeInfo, MapViewTypeInfo, TimeIndicatorTypeInfo, TimeIntervalTypeInfo}
import org.apache.calcite.rel.`type`._
import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rex.RexInputRef
import org.apache.calcite.sql.fun._
import org.apache.calcite.sql.validate.SqlMonotonicity
import org.apache.calcite.sql.{SqlKind, SqlRankFunction}
import org.apache.calcite.tools.RelBuilder
import java.time.Duration
import java.util
import scala.collection.JavaConversions._
import scala.collection.mutable
......@@ -737,4 +738,12 @@ object AggregateUtil extends Enumeration {
throw new IllegalArgumentException()
}
}
def extractTimeIntervalValue(literal: ValueLiteralExpression): JLong = {
if (isTimeIntervalType(literal.getType)) {
literal.getValue.asInstanceOf[JLong]
} else {
throw new IllegalArgumentException()
}
}
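// Illustrative usage (a sketch, not part of this change; `intervalLiteral` is a hypothetical
// ValueLiteralExpression of a time-interval type holding 60000L):
//   extractTimeIntervalValue(intervalLiteral) // returns 60000L as java.lang.Long
//   extractTimeIntervalValue(stringLiteral)   // non-interval type: IllegalArgumentException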
}
......@@ -19,10 +19,11 @@
package org.apache.flink.table.plan.util
import org.apache.flink.table.JDouble
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.dataformat.BinaryRow
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank}
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecGroupAggregateBase
import org.apache.flink.table.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase}
import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange}
import org.apache.flink.table.runtime.sort.BinaryIndexedSortable
import org.apache.flink.table.typeutils.BinaryRowSerializer
......@@ -147,6 +148,80 @@ object FlinkRelMdUtil {
def getAggregationRatioIfNdvUnavailable(groupingLength: Int): JDouble =
1.0 - math.exp(-0.1 * groupingLength)
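// Behavior sketch (not part of this change): the estimated ratio grows with the number of
// grouping keys but never reaches 1.0, e.g.
//   groupingLength = 1  => 1 - e^(-0.1) ≈ 0.095
//   groupingLength = 5  => 1 - e^(-0.5) ≈ 0.393
//   groupingLength = 20 => 1 - e^(-2.0) ≈ 0.865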
/**
* Creates a RexNode that stores a selectivity value corresponding to the
* selectivity of a NamedProperties predicate.
*
* @param winAgg window aggregate node
* @param predicate a RexNode
* @return the reconstructed RexNode, consisting of the non-NamedProperties predicates plus
* an artificial predicate that stores the selectivity of the NamedProperties predicates
*/
def makeNamePropertiesSelectivityRexNode(
winAgg: WindowAggregate,
predicate: RexNode): RexNode = {
val fullGroupSet = AggregateUtil.checkAndGetFullGroupSet(winAgg)
makeNamePropertiesSelectivityRexNode(winAgg, fullGroupSet, winAgg.getNamedProperties, predicate)
}
/**
* Creates a RexNode that stores a selectivity value corresponding to the
* selectivity of a NamedProperties predicate.
*
* @param globalWinAgg global window aggregate node
* @param predicate a RexNode
* @return the reconstructed RexNode, consisting of the non-NamedProperties predicates plus
* an artificial predicate that stores the selectivity of the NamedProperties predicates
*/
def makeNamePropertiesSelectivityRexNode(
globalWinAgg: BatchExecWindowAggregateBase,
predicate: RexNode): RexNode = {
require(globalWinAgg.isFinal, "local window agg does not contain NamedProperties!")
val fullGrouping = globalWinAgg.getGrouping ++ globalWinAgg.getAuxGrouping
makeNamePropertiesSelectivityRexNode(
globalWinAgg, fullGrouping, globalWinAgg.getNamedProperties, predicate)
}
/**
* Creates a RexNode that stores a selectivity value corresponding to the
* selectivity of a NamedProperties predicate.
*
* @param winAgg window aggregate node
* @param fullGrouping the full grouping keys
* @param namedProperties NamedWindowProperty list
* @param predicate a RexNode
* @return the reconstructed RexNode, consisting of the non-NamedProperties predicates plus
* an artificial predicate that stores the selectivity of the NamedProperties predicates
*/
def makeNamePropertiesSelectivityRexNode(
winAgg: SingleRel,
fullGrouping: Array[Int],
namedProperties: Seq[NamedWindowProperty],
predicate: RexNode): RexNode = {
if (predicate == null || predicate.isAlwaysTrue || namedProperties.isEmpty) {
return predicate
}
val rexBuilder = winAgg.getCluster.getRexBuilder
val namePropertiesStartIdx = winAgg.getRowType.getFieldCount - namedProperties.size
// split non-nameProperties predicates and nameProperties predicates
val pushable = new util.ArrayList[RexNode]
val notPushable = new util.ArrayList[RexNode]
RelOptUtil.splitFilters(
ImmutableBitSet.range(0, namePropertiesStartIdx),
predicate,
pushable,
notPushable)
if (notPushable.nonEmpty) {
val pred = RexUtil.composeConjunction(rexBuilder, notPushable, true)
val selectivity = RelMdUtil.guessSelectivity(pred)
val fun = rexBuilder.makeCall(
RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC,
rexBuilder.makeApproxLiteral(new BigDecimal(selectivity)))
pushable.add(fun)
}
RexUtil.composeConjunction(rexBuilder, pushable, true)
}
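// Worked sketch (hypothetical predicate, not part of this change): for a window aggregate
// whose output is [a, cnt, w$start, w$end] (namePropertiesStartIdx = 2), the predicate
// AND(>($1, 10), =($3, 100000)) is split into pushable [>($1, 10)] and notPushable
// [=($3, 100000)]; the latter is replaced by a call to RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC
// carrying guessSelectivity(=($3, 100000)), so later estimation still accounts for its
// selectivity without referencing the named-property column.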
/**
* Estimates outputRowCount of local aggregate.
*
......@@ -212,10 +287,34 @@ object FlinkRelMdUtil {
setChildKeysOfAgg(groupKey, aggRel)
}
/**
* Takes a bitmap representing a set of input references and extracts the
* ones that reference the group by columns in an aggregate.
*
* @param groupKey the original bitmap
* @param aggRel the aggregate
*/
def setAggChildKeys(
groupKey: ImmutableBitSet,
aggRel: BatchExecWindowAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = {
require(!aggRel.isFinal || !aggRel.isMerge, "Cannot handle global agg which has local agg!")
setChildKeysOfAgg(groupKey, aggRel)
}
private def setChildKeysOfAgg(
groupKey: ImmutableBitSet,
agg: SingleRel): (ImmutableBitSet, Array[AggregateCall]) = {
val (aggCalls, fullGroupSet) = agg match {
case agg: BatchExecLocalSortWindowAggregate =>
// grouping + assignTs + auxGrouping
(agg.getAggCallList,
agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping)
case agg: BatchExecLocalHashWindowAggregate =>
// grouping + assignTs + auxGrouping
(agg.getAggCallList,
agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping)
case agg: BatchExecWindowAggregateBase =>
(agg.getAggCallList, agg.getGrouping ++ agg.getAuxGrouping)
case agg: BatchExecGroupAggregateBase =>
(agg.getAggCallList, agg.getGrouping ++ agg.getAuxGrouping)
case _ => throw new IllegalArgumentException(s"Unknown aggregate: ${agg.getRelTypeName}")
......@@ -237,7 +336,34 @@ object FlinkRelMdUtil {
}
/**
* Split groupKeys on Agregate/ BatchExecGroupAggregateBase/ BatchExecWindowAggregateBase
* Maps a bitmap representing a set of references on the global window aggregate to the
* corresponding references on its input (the local window aggregate).
*
* global win-agg output type: groupSet + auxGroupSet + aggCalls + namedProperties
* local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
*
* Skips `assignTs` when mapping `groupKey` to `childKey`.
*
* @param groupKey the original bitmap
* @param globalWinAgg the global window aggregate
*/
def setChildKeysOfWinAgg(
groupKey: ImmutableBitSet,
globalWinAgg: BatchExecWindowAggregateBase): ImmutableBitSet = {
require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local window agg!")
val childKeyBuilder = ImmutableBitSet.builder
groupKey.toArray.foreach { key =>
if (key < globalWinAgg.getGrouping.length) {
childKeyBuilder.set(key)
} else {
// skips `assignTs`
childKeyBuilder.set(key + 1)
}
}
childKeyBuilder.build()
}
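// Index-mapping sketch (hypothetical shapes, not part of this change): with getGrouping of
// length 2 and no auxGrouping, the global output is [g0, g1, aggCalls..., props...] while
// the local output is [g0, g1, assignTs, aggCalls...], so
//   groupKey {0, 1} => childKey {0, 1}   (grouping keys keep their positions)
//   groupKey {0, 3} => childKey {0, 4}   (key 3 is shifted past `assignTs`)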
/**
* Split groupKeys on Aggregate/ BatchExecGroupAggregateBase/ BatchExecWindowAggregateBase
* into keys on aggregate's groupKey and aggregate's aggregateCalls.
*
* @param agg the aggregate
......@@ -271,6 +397,10 @@ object FlinkRelMdUtil {
val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel)
val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.getGrouping, rel.getAuxGrouping)
(childKeyExcludeAuxKey, aggCalls)
case rel: BatchExecWindowAggregateBase =>
val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel)
val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.getGrouping, rel.getAuxGrouping)
(childKeyExcludeAuxKey, aggCalls)
case _ => throw new IllegalArgumentException(s"Unknown aggregate: ${agg.getRelTypeName}.")
}
}
......@@ -306,6 +436,44 @@ object FlinkRelMdUtil {
splitPredicateOnAgg(agg.getGrouping ++ agg.getAuxGrouping, agg, predicate)
}
/**
* Splits a predicate on BatchExecWindowAggregateBase into two parts,
* the pushable part and the rest part.
*
* @param agg the window aggregate to analyze
* @param predicate the predicate to analyze
* @return a tuple whose first element is the pushable part and whose second element is
* the rest part. Note that the pushable condition is rewritten based on the
* input field positions.
*/
def splitPredicateOnAggregate(
agg: BatchExecWindowAggregateBase,
predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
splitPredicateOnAgg(agg.getGrouping ++ agg.getAuxGrouping, agg, predicate)
}
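// Split sketch (hypothetical predicate, not part of this change): for a window aggregate
// with grouping = [a] and aggregate call COUNT(c) AS cnt, the predicate
// AND(>(a, 1), >(cnt, 10)) yields the pushable part >(a, 1), rewritten against the input
// field positions, and the rest part >(cnt, 10), which must stay above the aggregate.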
/**
* Shifts every [[RexInputRef]] in an expression whose index is greater than or equal to
* the length of the full grouping up by one (to skip `assignTs`).
*
* global win-agg output type: groupSet + auxGroupSet + aggCalls + namedProperties
* local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
*
* @param predicate a RexNode
* @param globalWinAgg the global window aggregate
*/
def setChildPredicateOfWinAgg(
predicate: RexNode,
globalWinAgg: BatchExecWindowAggregateBase): RexNode = {
require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local window agg!")
if (predicate == null) {
return null
}
// grouping + assignTs + auxGrouping
val fullGrouping = globalWinAgg.getGrouping ++ globalWinAgg.getAuxGrouping
// skips `assignTs`
RexUtil.shift(predicate, fullGrouping.length, 1)
}
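// Shift sketch (hypothetical indexes, not part of this change): with a full grouping of
// length 2 and no auxGrouping, the local output is [g0, g1, assignTs, aggCalls...]; a
// predicate >($2, 0) on the global output (its first agg result) becomes >($3, 0) on the
// local output, while $0 and $1 are left untouched by RexUtil.shift(predicate, 2, 1).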
private def splitPredicateOnAgg(
grouping: Array[Int],
agg: SingleRel,
......
......@@ -596,10 +596,11 @@ LogicalSink(fields=[a, sum_c, time])
</Resource>
<Resource name="planAfter">
<![CDATA[
HashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[a, SUM(c) AS sum_c], reuse_id=[1])
HashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[a, Final_SUM(sum$0) AS sum_c], reuse_id=[1])
+- Exchange(distribution=[hash[a]])
+- Calc(select=[ts, a, CAST(c) AS c])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, ts)]]], fields=[a, b, c, ts])
+- LocalHashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[a, Partial_SUM(c) AS sum$0])
+- Calc(select=[ts, a, CAST(c) AS c])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, ts)]]], fields=[a, b, c, ts])
Sink(fields=[a, sum_c, time, window_start, window_end])
+- Calc(select=[a, sum_c, w$end AS time, w$start AS window_start, w$end AS window_end])
......
......@@ -40,10 +40,11 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[$2], EXPR$2=[$3], EXPR$3=[$4], EXPR$4=[TUMBL
<Resource name="planAfter">
<![CDATA[
Calc(select=[/(-($f0, /(*(CAST($f1), CAST($f1)), $f2)), $f2) AS EXPR$0, /(-($f0, /(*(CAST($f1), CAST($f1)), $f2)), CASE(=($f2, 1), null:BIGINT, -($f2, 1))) AS EXPR$1, POWER(/(-($f0, /(*(CAST($f1), CAST($f1)), $f2)), $f2), 0.5:DECIMAL(2, 1)) AS EXPR$2, POWER(/(-($f0, /(*(CAST($f1), CAST($f1)), $f2)), CASE(=($f2, 1), null:BIGINT, -($f2, 1))), 0.5:DECIMAL(2, 1)) AS EXPR$3, w$start AS EXPR$4, w$end AS EXPR$5])
+- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[SUM($f2) AS $f0, SUM(b) AS $f1, COUNT(b) AS $f2])
+- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Final_SUM(sum$0) AS $f0, Final_SUM(sum$1) AS $f1, Final_COUNT(count$2) AS $f2])
+- Exchange(distribution=[single])
+- Calc(select=[ts, b, *(CAST(b), CAST(b)) AS $f2])
+- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
+- LocalHashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Partial_SUM($f2) AS sum$0, Partial_SUM(b) AS sum$1, Partial_COUNT(b) AS count$2])
+- Calc(select=[ts, b, *(CAST(b), CAST(b)) AS $f2])
+- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
......@@ -134,10 +135,11 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[HOP_START($0)])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0, w$start AS EXPR$1], where=[AND(>($f1, 0), =(EXTRACT(FLAG(QUARTER), w$start), 1:BIGINT))])
+- HashWindowAggregate(window=[SlidingGroupWindow('w$, ts, 60000.millis, 900000.millis)], properties=[w$start, w$end, w$rowtime], select=[COUNT(*) AS EXPR$0, SUM(a) AS $f1])
+- HashWindowAggregate(window=[SlidingGroupWindow('w$, ts, 60000.millis, 900000.millis)], properties=[w$start, w$end, w$rowtime], select=[Final_COUNT(count1$0) AS EXPR$0, Final_SUM(sum$1) AS $f1])
+- Exchange(distribution=[single])
+- Calc(select=[ts, a])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
+- LocalHashWindowAggregate(window=[SlidingGroupWindow('w$, ts, 60000.millis, 900000.millis)], properties=[w$start, w$end, w$rowtime], select=[Partial_COUNT(*) AS count1$0, Partial_SUM(a) AS sum$1])
+- Calc(select=[ts, a])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
]]>
</Resource>
</TestCase>
......@@ -248,10 +250,11 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[HOP_START($0)], EXPR$2=[HOP_END($0)])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0, w$start AS EXPR$1, w$end AS EXPR$2])
+- HashWindowAggregate(window=[SlidingGroupWindow('w$, b, 3000.millis, 3000.millis)], properties=[w$start, w$end, w$rowtime], select=[SUM(a) AS EXPR$0])
+- HashWindowAggregate(window=[SlidingGroupWindow('w$, b, 3000.millis, 3000.millis)], properties=[w$start, w$end, w$rowtime], select=[Final_SUM(sum$0) AS EXPR$0])
+- Exchange(distribution=[single])
+- Calc(select=[b, a])
+- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+- LocalHashWindowAggregate(window=[SlidingGroupWindow('w$, b, 3000.millis, 3000.millis)], properties=[w$start, w$end, w$rowtime], select=[Partial_SUM(a) AS sum$0])
+- Calc(select=[b, a])
+- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
......@@ -326,10 +329,11 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[$2])
</Resource>
<Resource name="planAfter">
<![CDATA[
HashWindowAggregate(window=[TumblingGroupWindow], select=[AVG(c) AS EXPR$0, SUM(a) AS EXPR$1])
HashWindowAggregate(window=[TumblingGroupWindow], select=[Final_AVG(sum$0, count$1) AS EXPR$0, Final_SUM(sum$2) AS EXPR$1])
+- Exchange(distribution=[single])
+- Calc(select=[b, c, a])
+- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+- LocalHashWindowAggregate(window=[TumblingGroupWindow], select=[Partial_AVG(c) AS (sum$0, count$1), Partial_SUM(a) AS sum$2])
+- Calc(select=[b, c, a])
+- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
......@@ -434,10 +438,11 @@ LogicalProject(sumA=[$1], cntB=[$2])
</Resource>
<Resource name="planAfter">
<![CDATA[
HashWindowAggregate(window=[TumblingGroupWindow], select=[SUM(a) AS sumA, COUNT(b) AS cntB])
HashWindowAggregate(window=[TumblingGroupWindow], select=[Final_SUM(sum$0) AS sumA, Final_COUNT(count$1) AS cntB])
+- Exchange(distribution=[single])
+- Calc(select=[ts, a, b])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
+- LocalHashWindowAggregate(window=[TumblingGroupWindow], select=[Partial_SUM(a) AS sum$0, Partial_COUNT(b) AS count$1])
+- Calc(select=[ts, a, b])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
]]>
</Resource>
</TestCase>
......@@ -596,10 +601,11 @@ LogicalProject(EXPR$0=[TUMBLE_START($0)], EXPR$1=[TUMBLE_END($0)], EXPR$2=[TUMBL
<Resource name="planAfter">
<![CDATA[
Calc(select=[w$start AS EXPR$0, w$end AS EXPR$1, w$rowtime AS EXPR$2, c, sumA, minB])
+- HashWindowAggregate(groupBy=[c], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[c, SUM(a) AS sumA, MIN(b) AS minB])
+- HashWindowAggregate(groupBy=[c], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[c, Final_SUM(sum$0) AS sumA, Final_MIN(min$1) AS minB])
+- Exchange(distribution=[hash[c]])
+- Calc(select=[ts, c, a, b])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
+- LocalHashWindowAggregate(groupBy=[c], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[c, Partial_SUM(a) AS sum$0, Partial_MIN(b) AS min$1])
+- Calc(select=[ts, c, a, b])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
]]>
</Resource>
</TestCase>
......@@ -750,11 +756,13 @@ LogicalProject(EXPR$0=[$2])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0])
+- SortWindowAggregate(groupBy=[a], window=[SlidingGroupWindow('w$, ts, 3600000.millis, 3000.millis)], select=[a, MAX(c) AS EXPR$0])
+- Sort(orderBy=[a ASC, ts ASC])
+- SortWindowAggregate(groupBy=[a], window=[SlidingGroupWindow('w$, ts, 3600000.millis, 3000.millis)], select=[a, Final_MAX(max$0) AS EXPR$0])
+- Sort(orderBy=[a ASC, assignedPane$ ASC])
+- Exchange(distribution=[hash[a]])
+- Calc(select=[a, ts, c])
+- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
+- LocalSortWindowAggregate(groupBy=[a], window=[SlidingGroupWindow('w$, ts, 3600000.millis, 3000.millis)], select=[a, Partial_MAX(c) AS max$0])
+- Sort(orderBy=[a ASC, ts ASC])
+- Calc(select=[a, ts, c])
+- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
......@@ -1004,10 +1012,11 @@ LogicalProject(EXPR$0=[$2])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0])
+- HashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, COUNT(c) AS EXPR$0])
+- HashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, Final_COUNT(count$0) AS EXPR$0])
+- Exchange(distribution=[hash[a]])
+- Calc(select=[a, ts, c])
+- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
+- LocalHashWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, Partial_COUNT(c) AS count$0])
+- Calc(select=[a, ts, c])
+- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
......@@ -1071,9 +1080,10 @@ LogicalProject(EXPR$0=[$3], EXPR$1=[$4])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0, EXPR$1])
+- HashWindowAggregate(groupBy=[a, d], window=[TumblingGroupWindow], select=[a, d, AVG(c) AS EXPR$0, COUNT(a) AS EXPR$1])
+- HashWindowAggregate(groupBy=[a, d], window=[TumblingGroupWindow], select=[a, d, Final_AVG(sum$0, count$1) AS EXPR$0, Final_COUNT(count$2) AS EXPR$1])
+- Exchange(distribution=[hash[a, d]])
+- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+- LocalHashWindowAggregate(groupBy=[a, d], window=[TumblingGroupWindow], select=[a, d, Partial_AVG(c) AS (sum$0, count$1), Partial_COUNT(a) AS count$2])
+- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
......@@ -1134,11 +1144,13 @@ LogicalProject(wAvg=[$1])
</Resource>
<Resource name="planAfter">
<![CDATA[
SortWindowAggregate(window=[TumblingGroupWindow], select=[weightedAvg(b, a) AS wAvg])
+- Sort(orderBy=[ts ASC])
SortWindowAggregate(window=[TumblingGroupWindow], select=[Final_weightedAvg(wAvg) AS wAvg])
+- Sort(orderBy=[assignedWindow$ ASC])
+- Exchange(distribution=[single])
+- Calc(select=[ts, b, a])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
+- LocalSortWindowAggregate(window=[TumblingGroupWindow], select=[Partial_weightedAvg(b, a) AS wAvg])
+- Sort(orderBy=[ts ASC])
+- Calc(select=[ts, b, a])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
]]>
</Resource>
</TestCase>
......@@ -1157,11 +1169,13 @@ LogicalProject(EXPR$0=[$2])
<Resource name="planAfter">
<![CDATA[
Calc(select=[EXPR$0])
+- SortWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, MAX(c) AS EXPR$0])
+- Sort(orderBy=[a ASC, ts ASC])
+- SortWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, Final_MAX(max$0) AS EXPR$0])
+- Sort(orderBy=[a ASC, assignedWindow$ ASC])
+- Exchange(distribution=[hash[a]])
+- Calc(select=[a, ts, c])
+- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
+- LocalSortWindowAggregate(groupBy=[a], window=[TumblingGroupWindow], select=[a, Partial_MAX(c) AS max$0])
+- Sort(orderBy=[a ASC, ts ASC])
+- Calc(select=[a, ts, c])
+- TableSourceScan(table=[[MyTable1, source: [TestTableSource(ts, a, b, c)]]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
......@@ -1298,8 +1312,9 @@ LogicalProject(EXPR$0=[TUMBLE_END($0)])
Calc(select=[w$end AS EXPR$0])
+- HashWindowAggregate(groupBy=[c], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[c])
+- Exchange(distribution=[hash[c]])
+- Calc(select=[ts, c])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
+- LocalHashWindowAggregate(groupBy=[c], window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[c])
+- Calc(select=[ts, c])
+- TableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c, d, ts)]]], fields=[a, b, c, d, ts])
]]>
</Resource>
</TestCase>
......
......@@ -61,7 +61,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2], rn=[$3])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0, w1$o0], where=[<(w0$o0, 10)])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -87,7 +87,7 @@ LogicalProject(a=[$0], b=[$1], rk1=[$2], rk2=[$3])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0, w1$o0], where=[<(w0$o0, 10)])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {2} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {2} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -161,7 +161,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[>(w0$o0, 2)])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first, 2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first, 2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -233,7 +233,7 @@ LogicalProject(a=[$0], b=[$1], rn=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[<=(w0$o0, 2)])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -257,7 +257,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[<(w0$o0, a)])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -281,7 +281,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[>(w0$o0, a)])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -305,7 +305,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[AND(<(w0$o0, a), >(CAST(b), 5:BIGINT))])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -329,7 +329,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[=(w0$o0, b)])
+- FlinkLogicalOverWindow(window#0=[window(partition {0} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {0} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -371,7 +371,7 @@ LogicalProject(a=[$0], b=[$1], rk=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULL
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0 AS $2])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......
......@@ -37,7 +37,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2], rn=[$3])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0, w1$o0], where=[<(w0$o0, 10)])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {1} order by [0 ASC-nulls-first] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -63,7 +63,7 @@ LogicalProject(a=[$0], b=[$1], rk1=[$2], rk2=[$3])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0, w1$o0], where=[<(w0$o0, 10)])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {2} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#1=[window(partition {2} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -235,7 +235,7 @@ LogicalProject(a=[$0], b=[$1], rk=[$2])
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0], where=[>(w0$o0, a)])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......@@ -325,7 +325,7 @@ LogicalProject(a=[$0], b=[$1], rk=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULL
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, b, w0$o0 AS $2])
+- FlinkLogicalOverWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalOverAggregate(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])])
+- FlinkLogicalTableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
......
......@@ -28,7 +28,7 @@ import org.junit.Test
import java.sql.Timestamp
class OverWindowAggregateTest extends TableTestBase {
class OverAggregateTest extends TableTestBase {
private val util = batchTestUtil()
util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
......
......@@ -448,8 +448,38 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase {
}
@Test
def testGetColumnIntervalOnOverWindowAgg(): Unit = {
Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach {
def testGetColumnIntervalOnWindowAgg(): Unit = {
Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithLocalAgg,
batchGlobalWindowAggWithoutLocalAgg, streamWindowAgg).foreach { agg =>
assertEquals(ValueInterval(5, 45), mq.getColumnInterval(agg, 0))
assertEquals(null, mq.getColumnInterval(agg, 1))
assertEquals(RightSemiInfiniteValueInterval(0), mq.getColumnInterval(agg, 2))
assertEquals(null, mq.getColumnInterval(agg, 3))
}
assertEquals(ValueInterval(5, 45), mq.getColumnInterval(batchLocalWindowAgg, 0))
assertEquals(null, mq.getColumnInterval(batchLocalWindowAgg, 1))
assertEquals(null, mq.getColumnInterval(batchLocalWindowAgg, 2))
assertEquals(RightSemiInfiniteValueInterval(0), mq.getColumnInterval(batchLocalWindowAgg, 3))
assertEquals(null, mq.getColumnInterval(batchLocalWindowAgg, 4))
Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
batchGlobalWindowAggWithLocalAggWithAuxGroup,
batchGlobalWindowAggWithoutLocalAggWithAuxGroup).foreach { agg =>
assertEquals(ValueInterval(5, 55), mq.getColumnInterval(agg, 0))
assertEquals(ValueInterval(0, 50), mq.getColumnInterval(agg, 1))
assertEquals(ValueInterval(0, null), mq.getColumnInterval(agg, 2))
assertEquals(null, mq.getColumnInterval(agg, 3))
}
assertEquals(ValueInterval(5, 55), mq.getColumnInterval(batchLocalWindowAggWithAuxGroup, 0))
assertEquals(null, mq.getColumnInterval(batchLocalWindowAggWithAuxGroup, 1))
assertEquals(ValueInterval(0, 50), mq.getColumnInterval(batchLocalWindowAggWithAuxGroup, 2))
assertEquals(ValueInterval(0, null), mq.getColumnInterval(batchLocalWindowAggWithAuxGroup, 3))
assertEquals(null, mq.getColumnInterval(batchLocalWindowAggWithAuxGroup, 4))
}
@Test
def testGetColumnIntervalOnOverAgg(): Unit = {
Array(flinkLogicalOverAgg, batchOverAgg).foreach {
agg =>
assertEquals(ValueInterval(0, null), mq.getColumnInterval(agg, 0))
assertEquals(null, mq.getColumnInterval(agg, 1))
......@@ -464,14 +494,14 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase {
assertNull(mq.getColumnInterval(agg, 10))
}
assertEquals(ValueInterval(0, null), mq.getColumnInterval(streamOverWindowAgg, 0))
assertEquals(null, mq.getColumnInterval(streamOverWindowAgg, 1))
assertEquals(ValueInterval(2.7, 4.8), mq.getColumnInterval(streamOverWindowAgg, 2))
assertEquals(ValueInterval(12, 18), mq.getColumnInterval(streamOverWindowAgg, 3))
assertNull(mq.getColumnInterval(streamOverWindowAgg, 4))
assertNull(mq.getColumnInterval(streamOverWindowAgg, 5))
assertNull(mq.getColumnInterval(streamOverWindowAgg, 6))
assertNull(mq.getColumnInterval(streamOverWindowAgg, 7))
assertEquals(ValueInterval(0, null), mq.getColumnInterval(streamOverAgg, 0))
assertEquals(null, mq.getColumnInterval(streamOverAgg, 1))
assertEquals(ValueInterval(2.7, 4.8), mq.getColumnInterval(streamOverAgg, 2))
assertEquals(ValueInterval(12, 18), mq.getColumnInterval(streamOverAgg, 3))
assertNull(mq.getColumnInterval(streamOverAgg, 4))
assertNull(mq.getColumnInterval(streamOverAgg, 5))
assertNull(mq.getColumnInterval(streamOverAgg, 6))
assertNull(mq.getColumnInterval(streamOverAgg, 7))
}
@Test
......@@ -490,7 +520,7 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase {
assertEquals(ValueInterval(1L, 800000000L), mq.getColumnInterval(join, 1))
assertNull(mq.getColumnInterval(join, 2))
assertNull(mq.getColumnInterval(join, 3))
assertEquals(ValueInterval(1L, 100L),mq.getColumnInterval(join, 4))
assertEquals(ValueInterval(1L, 100L), mq.getColumnInterval(join, 4))
assertNull(mq.getColumnInterval(join, 5))
assertEquals(ValueInterval(8L, 1000L), mq.getColumnInterval(join, 6))
assertNull(mq.getColumnInterval(join, 7))
......
......@@ -355,8 +355,57 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase {
}
@Test
def testAreColumnsUniqueOnOverWindow(): Unit = {
Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
def testAreColumnsUniqueOnWindowAgg(): Unit = {
Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithLocalAgg,
batchGlobalWindowAggWithoutLocalAgg, streamWindowAgg).foreach { agg =>
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 2)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 3)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 2)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 3)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 4)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 5)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 6)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 3, 4, 5, 6)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 2, 3)))
}
assertNull(mq.areColumnsUnique(batchLocalWindowAgg, ImmutableBitSet.of(0, 1)))
assertNull(mq.areColumnsUnique(batchLocalWindowAgg, ImmutableBitSet.of(0, 1, 3)))
Array(logicalWindowAgg2, flinkLogicalWindowAgg2, batchGlobalWindowAggWithLocalAgg2,
batchGlobalWindowAggWithoutLocalAgg2, streamWindowAgg2).foreach { agg =>
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 2)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 3)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 4)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 5)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 2, 3, 4, 5)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(1, 2)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(1, 3)))
}
assertNull(mq.areColumnsUnique(batchLocalWindowAgg2, ImmutableBitSet.of(0, 1)))
assertNull(mq.areColumnsUnique(batchLocalWindowAgg2, ImmutableBitSet.of(0, 2)))
Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
batchGlobalWindowAggWithLocalAggWithAuxGroup, batchGlobalWindowAggWithoutLocalAggWithAuxGroup
).foreach { agg =>
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 2)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 1, 2)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 3)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 4)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 5)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 6)))
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 3, 4, 5, 6)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(1, 3)))
}
assertNull(mq.areColumnsUnique(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1)))
assertNull(mq.areColumnsUnique(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 3)))
}
@Test
def testAreColumnsUniqueOnOverAgg(): Unit = {
Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(1)))
assertFalse(mq.areColumnsUnique(agg, ImmutableBitSet.of(2)))
......@@ -375,20 +424,20 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase {
assertTrue(mq.areColumnsUnique(agg, ImmutableBitSet.of(0, 10)))
assertNull(mq.areColumnsUnique(agg, ImmutableBitSet.of(5, 10)))
}
assertTrue(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(0)))
assertFalse(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(1)))
assertFalse(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(2)))
assertFalse(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(3)))
assertFalse(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(4)))
assertNull(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(5)))
assertNull(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(6)))
assertNull(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(7)))
assertTrue(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(0, 1)))
assertTrue(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(0, 2)))
assertFalse(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(1, 2)))
assertTrue(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(0, 5)))
assertTrue(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(0, 7)))
assertNull(mq.areColumnsUnique(streamOverWindowAgg, ImmutableBitSet.of(5, 7)))
assertTrue(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(0)))
assertFalse(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(1)))
assertFalse(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(2)))
assertFalse(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(3)))
assertFalse(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(4)))
assertNull(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(5)))
assertNull(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(6)))
assertNull(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(7)))
assertTrue(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(0, 1)))
assertTrue(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(0, 2)))
assertFalse(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(1, 2)))
assertTrue(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(0, 5)))
assertTrue(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(0, 7)))
assertNull(mq.areColumnsUnique(streamOverAgg, ImmutableBitSet.of(5, 7)))
}
@Test
......@@ -465,38 +514,38 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase {
@Test
def testAreColumnsUniqueOnIntersect(): Unit = {
assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (0)))
assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (1)))
assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (2)))
assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (1, 2)))
assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (0, 2)))
assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of (1, 2)))
assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (0)))
assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (1)))
assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (2)))
assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (1, 2)))
assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (0, 2)))
assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of (1, 2)))
assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(0)))
assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(1)))
assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(2)))
assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(1, 2)))
assertFalse(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(0, 2)))
assertTrue(mq.areColumnsUnique(logicalIntersectAll, ImmutableBitSet.of(1, 2)))
assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(0)))
assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(1)))
assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(2)))
assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(1, 2)))
assertFalse(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(0, 2)))
assertTrue(mq.areColumnsUnique(logicalIntersect, ImmutableBitSet.of(1, 2)))
assertTrue(mq.areColumnsUnique(logicalIntersect,
ImmutableBitSet.range(logicalIntersect.getRowType.getFieldCount)))
}
@Test
def testAreColumnsUniqueOnMinus(): Unit = {
assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (0)))
assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (1)))
assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (2)))
assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (1, 2)))
assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (0, 2)))
assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of (1, 2)))
assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (0)))
assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (1)))
assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (2)))
assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (1, 2)))
assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (0, 2)))
assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of (1, 2)))
assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(0)))
assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(1)))
assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(2)))
assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(1, 2)))
assertFalse(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(0, 2)))
assertTrue(mq.areColumnsUnique(logicalMinusAll, ImmutableBitSet.of(1, 2)))
assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(0)))
assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(1)))
assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(2)))
assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(1, 2)))
assertFalse(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(0, 2)))
assertTrue(mq.areColumnsUnique(logicalMinus, ImmutableBitSet.of(1, 2)))
assertTrue(mq.areColumnsUnique(logicalMinus,
ImmutableBitSet.range(logicalMinus.getRowType.getFieldCount)))
......@@ -505,12 +554,12 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase {
.scan("MyTable2")
.scan("MyTable1")
.minus(false).build()
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (0)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (1)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (2)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (1, 2)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (0, 2)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of (1, 2)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(0)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(1)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(2)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(1, 2)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(0, 2)))
assertNull(mq.areColumnsUnique(logicalMinus2, ImmutableBitSet.of(1, 2)))
assertTrue(mq.areColumnsUnique(logicalMinus2,
ImmutableBitSet.range(logicalMinus2.getRowType.getFieldCount)))
}
......@@ -518,7 +567,7 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase {
@Test
def testGetColumnNullCountOnDefault(): Unit = {
(0 until testRel.getRowType.getFieldCount).foreach { idx =>
assertNull(mq.areColumnsUnique(testRel, ImmutableBitSet.of (idx)))
assertNull(mq.areColumnsUnique(testRel, ImmutableBitSet.of(idx)))
}
}
......
......@@ -19,6 +19,7 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecRank
import org.apache.flink.table.plan.util.FlinkRelMdUtil
import org.apache.calcite.rel.metadata.RelMdUtil
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
......@@ -428,8 +429,100 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase {
}
@Test
def testGetDistinctRowCountOnOverWindow(): Unit = {
Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
def testGetDistinctRowCountOnWindowAgg(): Unit = {
Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg,
batchGlobalWindowAggWithLocalAgg).foreach { agg =>
assertEquals(30D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0), null))
assertEquals(5D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), null))
assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), null))
assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 2), null))
assertEquals(null, mq.getDistinctRowCount(agg, ImmutableBitSet.of(3), null))
assertEquals(null, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 3), null))
assertEquals(null, mq.getDistinctRowCount(agg, ImmutableBitSet.of(1, 3), null))
assertEquals(null, mq.getDistinctRowCount(agg, ImmutableBitSet.of(2, 3), null))
relBuilder.clear()
// $1 > 10
val pred = relBuilder
.push(agg)
.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(10))
assertEquals(
FlinkRelMdUtil.adaptNdvBasedOnSelectivity(50.0D, 5.0D, 0.5D),
mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), pred), 1e-6)
assertEquals(25D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), pred))
// b > 10 and count(c) > 1 and w$end = 100000
val pred1 = relBuilder
.push(agg)
.and(
relBuilder.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(10)),
relBuilder.call(GREATER_THAN, relBuilder.field(2), relBuilder.literal(1)),
relBuilder.call(EQUALS, relBuilder.field(4), relBuilder.literal(100000))
)
assertEquals(
FlinkRelMdUtil.adaptNdvBasedOnSelectivity(50.0D, 5.0D, 0.075D),
mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), pred1), 1e-6)
assertEquals(25D * 0.15D * 1.0D,
mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), pred1), 1e-2)
}
assertEquals(30D, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(0), null))
assertEquals(5D, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(1), null))
assertEquals(50D, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(0, 1), null))
assertEquals(null, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(0, 2), null))
assertEquals(10D, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(3), null))
assertEquals(50D, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(0, 3), null))
assertEquals(50.0, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(1, 3), null))
assertEquals(null, mq.getDistinctRowCount(batchLocalWindowAgg, ImmutableBitSet.of(2, 3), null))
Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0), null))
assertEquals(48D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), null))
assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), null))
assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 2), null))
assertEquals(50D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(1, 2), null))
assertEquals(null, mq.getDistinctRowCount(agg, ImmutableBitSet.of(3), null))
relBuilder.clear()
// $1 > 10
val pred = relBuilder
.push(agg)
.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(10))
assertEquals(
FlinkRelMdUtil.adaptNdvBasedOnSelectivity(50.0D, 48.0D, 0.8D),
mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), pred), 1e-6)
assertEquals(40D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), pred))
// b > 10 and count(c) > 1 and w$end = 100000
val pred1 = relBuilder
.push(agg)
.and(
relBuilder.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(10)),
relBuilder.call(GREATER_THAN, relBuilder.field(2), relBuilder.literal(1)),
relBuilder.call(EQUALS, relBuilder.field(4), relBuilder.literal(100000))
)
assertEquals(
FlinkRelMdUtil.adaptNdvBasedOnSelectivity(50.0D, 48.0D, 0.12D),
mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), pred1), 1e-6)
assertEquals(40D * 0.15D * 1.0D, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0, 1), pred1))
}
assertEquals(50D,
mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0), null))
assertNull(mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(1), null))
assertNull(
mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1), null))
assertEquals(50D,
mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 2), null))
assertNull(
mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(1, 2), null))
assertEquals(10D,
mq.getDistinctRowCount(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(3), null))
}
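The expectations above lean on FlinkRelMdUtil.adaptNdvBasedOnSelectivity to shrink an NDV once a predicate has filtered the input. As a minimal sketch, assuming the classic distinct-value survival estimate (the actual Flink implementation may differ in detail):

// Hedged sketch, not Flink's exact code: each of the `ndv` distinct values
// appears among the `rowCount * selectivity` selected rows with probability
// 1 - (1 - 1/ndv)^selected, and the result can never exceed the selected rows.
def adaptNdvSketch(rowCount: Double, ndv: Double, selectivity: Double): Double = {
  val selected = rowCount * selectivity
  val survived = ndv * (1.0 - math.pow(1.0 - 1.0 / ndv, selected))
  math.min(survived, selected)
}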
@Test
def testGetDistinctRowCountOnOverAgg(): Unit = {
Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(1.0, mq.getDistinctRowCount(agg, ImmutableBitSet.of(), null))
assertEquals(50.0, mq.getDistinctRowCount(agg, ImmutableBitSet.of(0), null))
assertEquals(48.0, mq.getDistinctRowCount(agg, ImmutableBitSet.of(1), null))
......
......@@ -222,8 +222,52 @@ class FlinkRelMdPopulationSizeTest extends FlinkRelMdHandlerTestBase {
}
@Test
def testGetPopulationSizeOnOverWindow(): Unit = {
Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
def testGetPopulationSizeOnWindowAgg(): Unit = {
Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg,
batchGlobalWindowAggWithLocalAgg).foreach { agg =>
assertEquals(30D, mq.getPopulationSize(agg, ImmutableBitSet.of(0)))
assertEquals(5D, mq.getPopulationSize(agg, ImmutableBitSet.of(1)))
assertEquals(50D, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 1)))
assertEquals(50D, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 2)))
assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(3)))
assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 3)))
assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(1, 3)))
assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(2, 3)))
}
assertEquals(30D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(0)))
assertEquals(5D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(1)))
assertEquals(null, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(2)))
assertEquals(50D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(0, 1)))
assertEquals(null, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(0, 2)))
assertEquals(10D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(3)))
assertEquals(50D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(0, 3)))
assertEquals(50D, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(1, 3)))
assertEquals(null, mq.getPopulationSize(batchLocalWindowAgg, ImmutableBitSet.of(2, 3)))
Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
assertEquals(50D, mq.getPopulationSize(agg, ImmutableBitSet.of(0)))
assertEquals(48D, mq.getPopulationSize(agg, ImmutableBitSet.of(1)))
assertEquals(10D, mq.getPopulationSize(agg, ImmutableBitSet.of(2)))
assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(3)))
assertEquals(50D, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 1)))
assertEquals(50D, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 1, 2)))
assertEquals(null, mq.getPopulationSize(agg, ImmutableBitSet.of(0, 1, 3)))
}
assertEquals(50D, mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0)))
assertNull(mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(1)))
assertEquals(48D, mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(2)))
assertEquals(10D, mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(3)))
assertNull(mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1)))
assertEquals(50D,
mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 2)))
assertNull(mq.getPopulationSize(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1, 3)))
}
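What these population-size expectations encode: requests over named window properties yield null, and requests over pure grouping columns are remapped to input coordinates and delegated. A hedged sketch under that assumption — isWindowProperty is a hypothetical classifier, and the aggregate-call branch is deliberately elided (the assertions above, e.g. on ImmutableBitSet.of(0, 2), show the real handler in FlinkRelMdPopulationSize does cover it):

import java.lang.{Double => JDouble}

import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.util.ImmutableBitSet

// Sketch only; illustrative, not the planner's actual implementation.
def populationSizeSketch(
    input: RelNode, mq: RelMetadataQuery, grouping: Array[Int],
    isWindowProperty: Int => Boolean, groupKey: ImmutableBitSet): JDouble = {
  if (groupKey.toArray.exists(isWindowProperty)) {
    null // named window start/end properties carry no column statistics
  } else if (groupKey.toArray.forall(_ < grouping.length)) {
    // pure grouping columns: rewrite the key in input coordinates and delegate
    val inputKey = ImmutableBitSet.of(groupKey.toArray.map(grouping(_)): _*)
    mq.getPopulationSize(input, inputKey)
  } else {
    null // aggregate-call columns would need a separate estimate, elided here
  }
}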
@Test
def testGetPopulationSizeOnOverAgg(): Unit = {
Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(1.0, mq.getPopulationSize(agg, ImmutableBitSet.of()))
assertEquals(50.0, mq.getPopulationSize(agg, ImmutableBitSet.of(0)))
assertEquals(48.0, mq.getPopulationSize(agg, ImmutableBitSet.of(1)))
......
......@@ -18,9 +18,15 @@
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.plan.nodes.calcite.LogicalWindowAggregate
import org.apache.flink.table.plan.util.FlinkRelMdUtil
import com.google.common.collect.Lists
import org.apache.calcite.rel.core.{AggregateCall, Project}
import org.apache.calcite.rex.RexProgram
import org.apache.calcite.sql.fun.SqlCountAggFunction
import org.apache.calcite.sql.fun.SqlStdOperatorTable.LESS_THAN
import org.apache.calcite.util.ImmutableBitSet
import org.junit.Assert._
import org.junit.Test
......@@ -135,8 +141,38 @@ class FlinkRelMdRowCountTest extends FlinkRelMdHandlerTestBase {
}
@Test
def testGetRowCountOnOverWindow(): Unit = {
Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
def testGetRowCountOnWindowAgg(): Unit = {
Array(logicalWindowAgg, flinkLogicalWindowAgg, batchLocalWindowAgg,
batchGlobalWindowAggWithoutLocalAgg,
batchGlobalWindowAggWithLocalAgg, streamWindowAgg).foreach { agg =>
assertEquals(50D, mq.getRowCount(agg))
}
Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
batchLocalWindowAggWithAuxGroup,
batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
assertEquals(50D, mq.getRowCount(agg))
}
relBuilder.clear()
val ts = relBuilder.scan("TemporalTable3").peek()
val aggCallOfWindowAgg = Lists.newArrayList(AggregateCall.create(
new SqlCountAggFunction("COUNT"), false, false, List[Integer](3), -1, 2, ts, null, "s"))
val windowAgg = new LogicalWindowAggregate(
ts.getCluster,
ts.getTraitSet,
ts,
ImmutableBitSet.of(0, 1),
aggCallOfWindowAgg,
tumblingGroupWindow,
namedPropertiesOfWindowAgg)
assertEquals(4000000000D, mq.getRowCount(windowAgg))
}
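The hand-built window aggregate above is expected to report exactly the input size of TemporalTable3 (4000000000 rows after the statistics change further down in this diff). One plausible reading, as a hedged sketch that ignores the per-window multiplier a real handler would also weigh:

import java.lang.{Double => JDouble}

import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.util.ImmutableBitSet

// Sketch only: without NDV statistics on the group keys the estimate can do
// no better than the input row count; with NDV available it is capped by it.
def windowAggRowCountSketch(
    mq: RelMetadataQuery, input: RelNode, groupKey: ImmutableBitSet): Double = {
  val inputRows: Double = mq.getRowCount(input)
  val ndv: JDouble = mq.getDistinctRowCount(input, groupKey, null)
  if (ndv == null) inputRows else math.min(ndv, inputRows)
}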
@Test
def testGetRowCountOnOverAgg(): Unit = {
Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(50.0, mq.getRowCount(agg))
}
}
......
......@@ -20,7 +20,10 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable
import org.apache.flink.table.plan.nodes.calcite.LogicalExpand
import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalDataStreamTableScan, FlinkLogicalExpand, FlinkLogicalOverWindow}
import org.apache.flink.table.plan.nodes.logical.{
FlinkLogicalDataStreamTableScan, FlinkLogicalExpand, FlinkLogicalOverAggregate}
import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecCalc, BatchExecRank}
import org.apache.flink.table.plan.util.ExpandUtil
......@@ -355,7 +358,58 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase {
}
@Test
def testGetSelectivityOnOverWindow(): Unit = {
def testGetSelectivityOnWindowAgg(): Unit = {
Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg,
batchGlobalWindowAggWithLocalAgg).foreach { agg =>
relBuilder.clear()
relBuilder.push(agg)
// predicate with neither time fields nor aggCall fields
// a > 15
val predicate1 = relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(15))
assertEquals(0.75D, mq.getSelectivity(agg, predicate1))
// predicate with time fields but no aggCall fields
// a > 15 and w$end = 1000000
val predicate2 = relBuilder.and(
relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(15)),
relBuilder.call(EQUALS, relBuilder.field(4), relBuilder.literal(1000000))
)
assertEquals(0.75D * 0.15D, mq.getSelectivity(agg, predicate2))
// predicate with time fields and aggCall fields
// a > 15 and count(c) > 100 and w$end = 1000000
val predicate3 = relBuilder.and(
relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(15)),
relBuilder.call(GREATER_THAN, relBuilder.field(2), relBuilder.literal(100)),
relBuilder.call(EQUALS, relBuilder.field(4), relBuilder.literal(1000000))
)
assertEquals(0.75D * 0.15D * 0.01D, mq.getSelectivity(agg, predicate3))
}
Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
relBuilder.clear()
relBuilder.push(agg)
// a > 15
val predicate4 = relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(15))
assertEquals(0.8D, mq.getSelectivity(agg, predicate4))
// b > 15
val predicate5 = relBuilder.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(15))
assertEquals(0.7D, mq.getSelectivity(agg, predicate5))
// a > 15 and b > 15 and count(c) > 100 and w$end = 1000000
val predicate6 = relBuilder.and(
relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(15)),
relBuilder.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(15)),
relBuilder.call(GREATER_THAN, relBuilder.field(2), relBuilder.literal(100)),
relBuilder.call(EQUALS, relBuilder.field(4), relBuilder.literal(1000000))
)
assertEquals(0.8D * 0.7D * 0.15D * 0.01D, mq.getSelectivity(agg, predicate6))
}
}
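The expected products decompose conjunct by conjunct: predicates that touch only grouping columns are evaluated against the input (a > 15 contributes 0.75, b > 15 contributes 0.7), while the rest fall back to defaults — 0.15 per window-property conjunct and 0.01 per aggregate-call conjunct in these tests. A hedged sketch of that decomposition; the classifier functions and both constants are illustrative assumptions, not the planner's actual API:

import org.apache.calcite.rex.RexNode

def windowAggSelectivitySketch(
    conjuncts: Seq[RexNode],
    onGroupingOnly: RexNode => Boolean,
    onAggCall: RexNode => Boolean,
    inputSelectivity: RexNode => Double): Double =
  conjuncts.map { c =>
    if (onGroupingOnly(c)) inputSelectivity(c) // pushed to the input, e.g. a > 15 -> 0.75
    else if (onAggCall(c)) 0.01                // assumed default for agg-call predicates
    else 0.15                                  // assumed default for window-property predicates
  }.product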
@Test
def testGetSelectivityOnOverAgg(): Unit = {
// select a, b, c, d,
// rank() over (partition by c order by d) as rk,
// max(d) over(partition by c order by d) as max_d from MyTable4
......@@ -363,7 +417,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase {
ImmutableList.of(), -1, longType, "rk")
val maxAggCall = AggregateCall.create(SqlStdOperatorTable.MAX, false,
ImmutableList.of(Integer.valueOf(3)), -1, doubleType, "max_d")
val overWindowGroups = ImmutableList.of(new Window.Group(
val overAggGroups = ImmutableList.of(new Window.Group(
ImmutableBitSet.of(2),
true,
RexWindowBound.create(SqlWindow.createUnboundedPreceding(new SqlParserPos(0, 0)), null),
......@@ -383,8 +437,8 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase {
scan.getRowType.getFieldList.foreach(f => builder.add(f.getName, f.getType))
builder.add(rankAggCall.getName, rankAggCall.getType)
builder.add(maxAggCall.getName, maxAggCall.getType)
val overWindow = new FlinkLogicalOverWindow(cluster, flinkLogicalTraits, scan,
ImmutableList.of(), builder.build(), overWindowGroups)
val overWindow = new FlinkLogicalOverAggregate(cluster, flinkLogicalTraits, scan,
ImmutableList.of(), builder.build(), overAggGroups)
relBuilder.push(overWindow)
// a <= 10
......@@ -401,7 +455,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase {
relBuilder.call(LESS_THAN, relBuilder.field(4), relBuilder.literal(2)))
assertEquals(1 / 25.0 * ((10.0 - 1.0) / (50.0 - 1)) * 0.5, mq.getSelectivity(overWindow, pred3))
Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
relBuilder.clear()
relBuilder.push(agg)
......
......@@ -135,13 +135,31 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase {
}
@Test
def testAverageColumnSizeOnOverWindow(): Unit = {
Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
def testAverageColumnSizeOnWindowAgg(): Unit = {
Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg,
batchGlobalWindowAggWithLocalAgg).foreach { agg =>
assertEquals(Seq(4D, 32D, 8D, 12D, 12D, 12D, 12D), mq.getAverageColumnSizes(agg).toSeq)
}
assertEquals(Seq(4.0, 32.0, 8.0, 8.0),
mq.getAverageColumnSizes(batchLocalWindowAgg).toSeq)
Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
assertEquals(Seq(8D, 4D, 8D, 12D, 12D, 12D, 12D), mq.getAverageColumnSizes(agg).toSeq)
}
assertEquals(Seq(8D, 8D, 4D, 8D),
mq.getAverageColumnSizes(batchLocalWindowAggWithAuxGroup).toSeq)
}
@Test
def testAverageColumnSizeOnOverAgg(): Unit = {
Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(Seq(8.0, 7.2, 8.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0),
mq.getAverageColumnSizes(agg).toList)
}
assertEquals(Seq(8.0, 12.0, 8.0, 4.0, 4.0, 8.0, 8.0, 8.0),
mq.getAverageColumnSizes(streamOverWindowAgg).toList)
mq.getAverageColumnSizes(streamOverAgg).toList)
}
@Test
......
......@@ -311,8 +311,47 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase {
}
@Test
def testGetUniqueGroupsOnOverWindow(): Unit = {
Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
def testGetUniqueGroupsOnWindowAgg(): Unit = {
Array(logicalWindowAgg, flinkLogicalWindowAgg,
batchGlobalWindowAggWithoutLocalAgg,
batchGlobalWindowAggWithLocalAgg).foreach { agg =>
assertEquals(ImmutableBitSet.of(0, 1, 2, 3, 4, 5, 6),
mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2, 3, 4, 5, 6)))
assertEquals(ImmutableBitSet.of(3, 4, 5, 6),
mq.getUniqueGroups(agg, ImmutableBitSet.of(3, 4, 5, 6)))
assertEquals(ImmutableBitSet.of(0, 3, 4, 5, 6),
mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 3, 4, 5, 6)))
assertEquals(ImmutableBitSet.of(0, 1),
mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1)))
assertEquals(ImmutableBitSet.of(0, 1, 2),
mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2)))
}
assertEquals(ImmutableBitSet.of(0, 1, 2, 3),
mq.getUniqueGroups(batchLocalWindowAgg, ImmutableBitSet.of(0, 1, 2, 3)))
Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
assertEquals(ImmutableBitSet.of(1),
mq.getUniqueGroups(agg, ImmutableBitSet.of(1)))
assertEquals(ImmutableBitSet.of(0),
mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1)))
assertEquals(ImmutableBitSet.of(0, 1, 2),
mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2)))
assertEquals(ImmutableBitSet.of(0, 1, 2, 3),
mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2, 3)))
assertEquals(ImmutableBitSet.of(0, 1, 2, 3, 4, 5, 6),
mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2, 3, 4, 5, 6)))
}
assertEquals(ImmutableBitSet.of(0),
mq.getUniqueGroups(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1)))
assertEquals(ImmutableBitSet.of(0, 1, 2),
mq.getUniqueGroups(batchLocalWindowAggWithAuxGroup, ImmutableBitSet.of(0, 1, 2)))
}
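A compact restatement of what these unique-group assertions check, as a sketch in which delegateToInput stands for the usual remapping of columns to and from the aggregate's input:

import org.apache.calcite.util.ImmutableBitSet

// Only when every requested column is a (possibly auxiliary) group key can the
// question be delegated to the input; once an aggregate-call or window-property
// column is involved, the requested set is its own best answer.
def windowAggUniqueGroupsSketch(
    requested: ImmutableBitSet,
    groupingAndAux: ImmutableBitSet,
    delegateToInput: ImmutableBitSet => ImmutableBitSet): ImmutableBitSet =
  if (groupingAndAux.contains(requested)) delegateToInput(requested) else requested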
@Test
def testGetUniqueGroupsOnOverAgg(): Unit = {
Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1)))
assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(agg, ImmutableBitSet.of(0, 1, 2)))
assertEquals(ImmutableBitSet.of(1, 2), mq.getUniqueGroups(agg, ImmutableBitSet.of(1, 2)))
......
......@@ -21,7 +21,7 @@ package org.apache.flink.table.plan.metadata
import org.apache.flink.table.plan.nodes.calcite.LogicalExpand
import org.apache.flink.table.plan.util.ExpandUtil
import com.google.common.collect.ImmutableList
import com.google.common.collect.{ImmutableList, ImmutableSet}
import org.apache.calcite.sql.fun.SqlStdOperatorTable.{EQUALS, LESS_THAN}
import org.apache.calcite.util.ImmutableBitSet
import org.junit.Assert._
......@@ -169,12 +169,32 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase {
}
@Test
def testGetUniqueKeysOnOverWindow(): Unit = {
Array(flinkLogicalOverWindow, batchOverWindowAgg).foreach { agg =>
def testGetUniqueKeysOnWindowAgg(): Unit = {
Array(logicalWindowAgg, flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg,
batchGlobalWindowAggWithLocalAgg).foreach { agg =>
assertEquals(ImmutableSet.of(ImmutableBitSet.of(0, 1, 3), ImmutableBitSet.of(0, 1, 4),
ImmutableBitSet.of(0, 1, 5), ImmutableBitSet.of(0, 1, 6)),
mq.getUniqueKeys(agg))
}
assertNull(mq.getUniqueKeys(batchLocalWindowAgg))
Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
assertEquals(ImmutableSet.of(ImmutableBitSet.of(0, 3), ImmutableBitSet.of(0, 4),
ImmutableBitSet.of(0, 5), ImmutableBitSet.of(0, 6)),
mq.getUniqueKeys(agg))
}
assertNull(mq.getUniqueKeys(batchLocalWindowAggWithAuxGroup))
}
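All asserted key sets share one shape: the group keys plus exactly one named window property (presumably w$start, w$end, w$rowtime, w$proctime). A minimal sketch of that construction — propertyIndexes is a hypothetical argument; a local window aggregate exposes no named properties yet, which is why its unique keys are null above:

import scala.collection.JavaConverters._

import com.google.common.collect.ImmutableSet
import org.apache.calcite.util.ImmutableBitSet

def windowAggUniqueKeysSketch(
    groupKeys: ImmutableBitSet,
    propertyIndexes: Seq[Int]): java.util.Set[ImmutableBitSet] =
  ImmutableSet.copyOf(
    propertyIndexes.map(p => groupKeys.union(ImmutableBitSet.of(p))).asJava)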
@Test
def testGetUniqueKeysOnOverAgg(): Unit = {
Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg =>
assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(agg).toSet)
}
assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(streamOverWindowAgg).toSet)
assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(streamOverAgg).toSet)
}
@Test
......
......@@ -89,7 +89,9 @@ object MetadataTestUtil {
rootSchema.add("MyTable2", createMyTable2())
rootSchema.add("MyTable3", createMyTable3())
rootSchema.add("MyTable4", createMyTable4())
rootSchema.add("TemporalTable", createTemporalTable())
rootSchema.add("TemporalTable1", createTemporalTable1())
rootSchema.add("TemporalTable2", createTemporalTable2())
rootSchema.add("TemporalTable3", createTemporalTable3())
rootSchema
}
......@@ -213,7 +215,48 @@ object MetadataTestUtil {
getDataStreamTable(schema, new FlinkStatistic(tableStats, uniqueKeys))
}
private def createTemporalTable(): DataStreamTable[BaseRow] = {
private def createTemporalTable1(): DataStreamTable[BaseRow] = {
val fieldNames = Array("a", "b", "c", "proctime", "rowtime")
val fieldTypes = Array[InternalType](
InternalTypes.LONG,
InternalTypes.STRING,
InternalTypes.INT,
InternalTypes.PROCTIME_INDICATOR,
InternalTypes.ROWTIME_INDICATOR)
val colStatsMap = Map[String, ColumnStats](
"a" -> new ColumnStats(30L, 0L, 4D, 4, 45, 5),
"b" -> new ColumnStats(5L, 0L, 32D, 32, null, null),
"c" -> new ColumnStats(48L, 0L, 8D, 8, 50, 0)
)
val tableStats = new TableStats(50L, colStatsMap)
getDataStreamTable(fieldNames, fieldTypes, new FlinkStatistic(tableStats),
producesUpdates = false, isAccRetract = false)
}
private def createTemporalTable2(): DataStreamTable[BaseRow] = {
val fieldNames = Array("a", "b", "c", "proctime", "rowtime")
val fieldTypes = Array[InternalType](
InternalTypes.LONG,
InternalTypes.STRING,
InternalTypes.INT,
InternalTypes.PROCTIME_INDICATOR,
InternalTypes.ROWTIME_INDICATOR)
val colStatsMap = Map[String, ColumnStats](
"a" -> new ColumnStats(50L, 0L, 8D, 8, 55, 5),
"b" -> new ColumnStats(5L, 0L, 16D, 32, null, null),
"c" -> new ColumnStats(48L, 0L, 4D, 4, 50, 0)
)
val tableStats = new TableStats(50L, colStatsMap)
val uniqueKeys = Set(Set("a").asJava).asJava
getDataStreamTable(fieldNames, fieldTypes, new FlinkStatistic(tableStats, uniqueKeys),
producesUpdates = false, isAccRetract = false)
}
private def createTemporalTable3(): DataStreamTable[BaseRow] = {
val fieldNames = Array("a", "b", "c", "proctime", "rowtime")
val fieldTypes = Array[InternalType](
InternalTypes.INT,
......@@ -228,7 +271,7 @@ object MetadataTestUtil {
"c" -> new ColumnStats(null, 0L, 18.6, 64, null, null)
)
val tableStats = new TableStats(20000000L, colStatsMap)
val tableStats = new TableStats(4000000000L, colStatsMap)
getDataStreamTable(fieldNames, fieldTypes, new FlinkStatistic(tableStats),
producesUpdates = false, isAccRetract = false)
}
......
......@@ -27,7 +27,7 @@ import org.apache.flink.table.util.TableTestBase
import org.junit.Assert.assertEquals
import org.junit.Test
class OverWindowAggregateTest extends TableTestBase {
class OverAggregateTest extends TableTestBase {
private val util = streamTestUtil()
util.addDataStream[(Int, String, Long)]("MyTable", 'a, 'b, 'c, 'proctime, 'rowtime)
......