提交 07f1b035 编写于 作者: S sunjincheng121 提交者: Fabian Hueske

[FLINK-6257] [table] Consistent naming of ProcessFunction and methods for OVER windows.

- Add check for sort order of OVER windows.

This closes #3681.
上级 5ff9c99f
......@@ -17,20 +17,21 @@
*/
package org.apache.flink.table.plan.nodes.datastream
import java.util.{List => JList}
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.core.{AggregateCall, Window}
import org.apache.calcite.rel.core.Window.Group
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.api.{StreamTableEnvironment, TableException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.runtime.aggregate._
import org.apache.flink.table.plan.nodes.OverAggregate
import org.apache.flink.table.runtime.aggregate._
import org.apache.flink.types.Row
import org.apache.calcite.rel.core.Window
import org.apache.calcite.rel.core.Window.Group
import java.util.{List => JList}
import org.apache.flink.api.java.functions.NullByteKeySelector
import org.apache.flink.table.codegen.CodeGenerator
......@@ -90,12 +91,20 @@ class DataStreamOverAggregate(
val overWindow: org.apache.calcite.rel.core.Window.Group = logicWindow.groups.get(0)
val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
val orderKeys = overWindow.orderKeys.getFieldCollations
if (overWindow.orderKeys.getFieldCollations.size() != 1) {
if (orderKeys.size() != 1) {
throw new TableException(
"Unsupported use of OVER windows. The window may only be ordered by a single time column.")
"Unsupported use of OVER windows. The window can only be ordered by a single time column.")
}
val orderKey = orderKeys.get(0)
if (!orderKey.direction.equals(ASCENDING)) {
throw new TableException(
"Unsupported use of OVER windows. The window can only be ordered in ASCENDING mode.")
}
val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
val generator = new CodeGenerator(
tableEnv.getConfig,
......@@ -104,78 +113,69 @@ class DataStreamOverAggregate(
val timeType = inputType
.getFieldList
.get(overWindow.orderKeys.getFieldCollations.get(0).getFieldIndex)
.get(orderKey.getFieldIndex)
.getValue
timeType match {
case _: ProcTimeType =>
// proc-time OVER window
if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
// unbounded preceding OVER window
createUnboundedAndCurrentRowProcessingTimeOverWindow(
// unbounded OVER window
createUnboundedAndCurrentRowOverWindow(
generator,
inputDS)
inputDS,
isRowTimeType = false,
isRowsClause = overWindow.isRows)
} else if (
overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded &&
overWindow.upperBound.isCurrentRow) {
overWindow.upperBound.isCurrentRow) {
// bounded OVER window
if (overWindow.isRows) {
// ROWS clause bounded OVER window
createBoundedAndCurrentRowOverWindow(
generator,
inputDS,
isRangeClause = false,
isRowTimeType = false)
} else {
// RANGE clause bounded OVER window
createBoundedAndCurrentRowOverWindow(
generator,
inputDS,
isRangeClause = true,
isRowTimeType = false)
}
createBoundedAndCurrentRowOverWindow(
generator,
inputDS,
isRowTimeType = false,
isRowsClause = overWindow.isRows
)
} else {
throw new TableException(
"processing-time OVER RANGE FOLLOWING window is not supported yet.")
"OVER RANGE FOLLOWING windows are not supported yet.")
}
case _: RowTimeType =>
// row-time OVER window
if (overWindow.lowerBound.isPreceding &&
overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
// ROWS/RANGE clause unbounded OVER window
createUnboundedAndCurrentRowEventTimeOverWindow(
overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
// unbounded OVER window
createUnboundedAndCurrentRowOverWindow(
generator,
inputDS,
overWindow.isRows)
isRowTimeType = true,
isRowsClause = overWindow.isRows
)
} else if (overWindow.lowerBound.isPreceding && overWindow.upperBound.isCurrentRow) {
// bounded OVER window
if (overWindow.isRows) {
// ROWS clause bounded OVER window
createBoundedAndCurrentRowOverWindow(
generator,
inputDS,
isRangeClause = false,
isRowTimeType = true)
} else {
// RANGE clause bounded OVER window
createBoundedAndCurrentRowOverWindow(
generator,
inputDS,
isRangeClause = true,
isRowTimeType = true)
}
createBoundedAndCurrentRowOverWindow(
generator,
inputDS,
isRowTimeType = true,
isRowsClause = overWindow.isRows
)
} else {
throw new TableException(
"row-time OVER RANGE FOLLOWING window is not supported yet.")
"OVER RANGE FOLLOWING windows are not supported yet.")
}
case _ =>
throw new TableException(s"Unsupported time type {$timeType}")
throw new TableException(
"Unsupported time type {$timeType}. " +
"OVER windows do only support RowTimeType and ProcTimeType.")
}
}
def createUnboundedAndCurrentRowProcessingTimeOverWindow(
def createUnboundedAndCurrentRowOverWindow(
generator: CodeGenerator,
inputDS: DataStream[Row]): DataStream[Row] = {
inputDS: DataStream[Row],
isRowTimeType: Boolean,
isRowsClause: Boolean): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
......@@ -184,14 +184,17 @@ class DataStreamOverAggregate(
// get the output types
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
val processFunction = AggregateUtil.createUnboundedOverProcessFunction(
generator,
namedAggregates,
inputType,
isRowTimeType,
partitionKeys.nonEmpty,
isRowsClause)
val result: DataStream[Row] =
// partitioned aggregation
if (partitionKeys.nonEmpty) {
val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction(
generator,
namedAggregates,
inputType)
inputDS
.keyBy(partitionKeys: _*)
.process(processFunction)
......@@ -201,17 +204,19 @@ class DataStreamOverAggregate(
}
// non-partitioned aggregation
else {
val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction(
generator,
namedAggregates,
inputType,
isPartitioned = false)
inputDS
.process(processFunction).setParallelism(1).setMaxParallelism(1)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
if (isRowTimeType) {
inputDS.keyBy(new NullByteKeySelector[Row])
.process(processFunction).setParallelism(1).setMaxParallelism(1)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
} else {
inputDS
.process(processFunction).setParallelism(1).setMaxParallelism(1)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
}
result
}
......@@ -219,15 +224,15 @@ class DataStreamOverAggregate(
def createBoundedAndCurrentRowOverWindow(
generator: CodeGenerator,
inputDS: DataStream[Row],
isRangeClause: Boolean,
isRowTimeType: Boolean): DataStream[Row] = {
isRowTimeType: Boolean,
isRowsClause: Boolean): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
val precedingOffset =
getLowerBoundary(logicWindow, overWindow, getInput()) + (if (isRangeClause) 0 else 1)
getLowerBoundary(logicWindow, overWindow, getInput()) + (if (isRowsClause) 1 else 0)
// get the output types
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
......@@ -237,8 +242,9 @@ class DataStreamOverAggregate(
namedAggregates,
inputType,
precedingOffset,
isRangeClause,
isRowTimeType)
isRowsClause,
isRowTimeType
)
val result: DataStream[Row] =
// partitioned aggregation
if (partitionKeys.nonEmpty) {
......@@ -253,49 +259,7 @@ class DataStreamOverAggregate(
else {
inputDS
.keyBy(new NullByteKeySelector[Row])
.process(processFunction)
.setParallelism(1)
.setMaxParallelism(1)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
result
}
def createUnboundedAndCurrentRowEventTimeOverWindow(
generator: CodeGenerator,
inputDS: DataStream[Row],
isRows: Boolean): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
// get the output types
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
val processFunction = AggregateUtil.createUnboundedEventTimeOverProcessFunction(
generator,
namedAggregates,
inputType,
isRows)
val result: DataStream[Row] =
// partitioned aggregation
if (partitionKeys.nonEmpty) {
inputDS.keyBy(partitionKeys: _*)
.process(processFunction)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
// global non-partitioned aggregation
else {
inputDS.keyBy(new NullByteKeySelector[Row])
.process(processFunction)
.setParallelism(1)
.setMaxParallelism(1)
.process(processFunction).setParallelism(1).setMaxParallelism(1)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
......
......@@ -55,21 +55,23 @@ object AggregateUtil {
type JavaList[T] = java.util.List[T]
/**
* Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] to evaluate final
* aggregate value.
* Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for unbounded OVER
* window to evaluate final aggregate value.
*
* @param generator code generator instance
* @param namedAggregates List of calls to aggregate functions and their output field names
* @param inputType Input row type
* @param isPartitioned Flag to indicate whether the input is partitioned or not
*
* @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
* @param inputType Input row type
* @param isRowTimeType It is a tag that indicates whether the time type is rowTimeType
* @param isPartitioned It is a tag that indicates whether the input is partitioned
* @param isRowsClause It is a tag that indicates whether the OVER clause is ROWS clause
*/
private[flink] def createUnboundedProcessingOverProcessFunction(
generator: CodeGenerator,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
isPartitioned: Boolean = true): ProcessFunction[Row, Row] = {
private[flink] def createUnboundedOverProcessFunction(
generator: CodeGenerator,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
isRowTimeType: Boolean,
isPartitioned: Boolean,
isRowsClause: Boolean): ProcessFunction[Row, Row] = {
val (aggFields, aggregates) =
transformToAggregateFunctions(
......@@ -95,14 +97,30 @@ object AggregateUtil {
outputArity
)
if (isPartitioned) {
new UnboundedProcessingOverProcessFunction(
genFunction,
aggregationStateType)
if (isRowTimeType) {
if (isRowsClause) {
// ROWS unbounded over process function
new RowTimeUnboundedRowsOver(
genFunction,
aggregationStateType,
FlinkTypeFactory.toInternalRowTypeInfo(inputType))
} else {
// RANGE unbounded over process function
new RowTimeUnboundedRangeOver(
genFunction,
aggregationStateType,
FlinkTypeFactory.toInternalRowTypeInfo(inputType))
}
} else {
new UnboundedNonPartitionedProcessingOverProcessFunction(
genFunction,
aggregationStateType)
if (isPartitioned) {
new ProcTimeUnboundedPartitionedOver(
genFunction,
aggregationStateType)
} else {
new ProcTimeUnboundedNonPartitionedOver(
genFunction,
aggregationStateType)
}
}
}
......@@ -114,7 +132,7 @@ object AggregateUtil {
* @param namedAggregates List of calls to aggregate functions and their output field names
* @param inputType Input row type
* @param precedingOffset the preceding offset
* @param isRangeClause It is a tag that indicates whether the OVER clause is rangeClause
* @param isRowsClause It is a tag that indicates whether the OVER clause is ROWS clause
* @param isRowTimeType It is a tag that indicates whether the time type is rowTimeType
* @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
*/
......@@ -123,7 +141,7 @@ object AggregateUtil {
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
precedingOffset: Long,
isRangeClause: Boolean,
isRowsClause: Boolean,
isRowTimeType: Boolean): ProcessFunction[Row, Row] = {
val (aggFields, aggregates) =
......@@ -151,15 +169,15 @@ object AggregateUtil {
)
if (isRowTimeType) {
if (isRangeClause) {
new RangeClauseBoundedOverProcessFunction(
if (isRowsClause) {
new RowTimeBoundedRowsOver(
genFunction,
aggregationStateType,
inputRowType,
precedingOffset
)
} else {
new RowsClauseBoundedOverProcessFunction(
new RowTimeBoundedRangeOver(
genFunction,
aggregationStateType,
inputRowType,
......@@ -167,14 +185,14 @@ object AggregateUtil {
)
}
} else {
if (isRangeClause) {
new BoundedProcessingOverRangeProcessFunction(
if (isRowsClause) {
new ProcTimeBoundedRowsOver(
genFunction,
precedingOffset,
aggregationStateType,
inputRowType)
} else {
new BoundedProcessingOverRowProcessFunction(
new ProcTimeBoundedRangeOver(
genFunction,
precedingOffset,
aggregationStateType,
......@@ -183,58 +201,6 @@ object AggregateUtil {
}
}
/**
* Create a [[ProcessFunction]] that evaluates the final aggregate values of an
* unbounded event-time OVER window.
*
* The helper generated here is named "UnboundedEventTimeOverAggregateHelper"; the
* returned function is either the ROWS or the RANGE variant, selected by `isRows`.
*
* @param generator code generator instance
* @param namedAggregates List of calls to aggregate functions and their output field names
* @param inputType Input row type
* @param isRows Flag to indicate whether this is a ROWS (true) or a RANGE (false)
* over window process
* @return [[UnboundedEventTimeOverProcessFunction]]
*/
private[flink] def createUnboundedEventTimeOverProcessFunction(
generator: CodeGenerator,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
isRows: Boolean): UnboundedEventTimeOverProcessFunction = {
// resolve the aggregate calls into aggregate functions; retraction is not requested
val (aggFields, aggregates) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
needRetraction = false)
// row type built from the aggregates' accumulators
val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates)
// every input field i is mapped to position i (identity mapping)
val forwardMapping = (0 until inputType.getFieldCount).map(x => (x, x)).toArray
// aggregate k is mapped to position (input field count + k), i.e. after the input fields
val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray
// output arity = number of input fields + number of aggregates
val outputArity = inputType.getFieldCount + aggregates.length
val genFunction = generator.generateAggregations(
"UnboundedEventTimeOverAggregateHelper",
generator,
inputType,
aggregates,
aggFields,
aggMapping,
forwardMapping,
outputArity)
if (isRows) {
// ROWS unbounded over process function
new UnboundedEventTimeRowsOverProcessFunction(
genFunction,
aggregationStateType,
FlinkTypeFactory.toInternalRowTypeInfo(inputType))
} else {
// RANGE unbounded over process function
new UnboundedEventTimeRangeOverProcessFunction(
genFunction,
aggregationStateType,
FlinkTypeFactory.toInternalRowTypeInfo(inputType))
}
}
/**
* Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates.
......
......@@ -43,14 +43,13 @@ import org.slf4j.LoggerFactory
* @param aggregatesTypeInfo row type info of aggregation
* @param inputType row type info of input row
*/
class BoundedProcessingOverRangeProcessFunction(
class ProcTimeBoundedRangeOver(
genAggregations: GeneratedAggregationsFunction,
precedingTimeBoundary: Long,
aggregatesTypeInfo: RowTypeInfo,
inputType: TypeInformation[Row])
extends ProcessFunction[Row, Row]
with Compiler[GeneratedAggregations] {
private var output: Row = _
private var accumulatorState: ValueState[Row] = _
private var rowMapState: MapState[Long, JList[Row]] = _
......
......@@ -44,7 +44,7 @@ import org.slf4j.LoggerFactory
* @param aggregatesTypeInfo row type info of aggregation
* @param inputType row type info of input row
*/
class BoundedProcessingOverRowProcessFunction(
class ProcTimeBoundedRowsOver(
genAggregations: GeneratedAggregationsFunction,
precedingOffset: Long,
aggregatesTypeInfo: RowTypeInfo,
......
......@@ -34,7 +34,7 @@ import org.slf4j.LoggerFactory
* @param genAggregations Generated aggregate helper function
* @param aggregationStateType row type info of aggregation
*/
class UnboundedNonPartitionedProcessingOverProcessFunction(
class ProcTimeUnboundedNonPartitionedOver(
genAggregations: GeneratedAggregationsFunction,
aggregationStateType: RowTypeInfo)
extends ProcessFunction[Row, Row]
......
......@@ -33,7 +33,7 @@ import org.slf4j.LoggerFactory
* @param genAggregations Generated aggregate helper function
* @param aggregationStateType row type info of aggregation
*/
class UnboundedProcessingOverProcessFunction(
class ProcTimeUnboundedPartitionedOver(
genAggregations: GeneratedAggregationsFunction,
aggregationStateType: RowTypeInfo)
extends ProcessFunction[Row, Row]
......
......@@ -37,14 +37,13 @@ import org.slf4j.LoggerFactory
* @param inputRowType row type info of input row
* @param precedingOffset preceding offset
*/
class RangeClauseBoundedOverProcessFunction(
class RowTimeBoundedRangeOver(
genAggregations: GeneratedAggregationsFunction,
aggregationStateType: RowTypeInfo,
inputRowType: RowTypeInfo,
precedingOffset: Long)
extends ProcessFunction[Row, Row]
with Compiler[GeneratedAggregations] {
Preconditions.checkNotNull(aggregationStateType)
Preconditions.checkNotNull(precedingOffset)
......
......@@ -38,7 +38,7 @@ import org.slf4j.LoggerFactory
* @param inputRowType row type info of input row
* @param precedingOffset preceding offset
*/
class RowsClauseBoundedOverProcessFunction(
class RowTimeBoundedRowsOver(
genAggregations: GeneratedAggregationsFunction,
aggregationStateType: RowTypeInfo,
inputRowType: RowTypeInfo,
......
......@@ -39,7 +39,7 @@ import org.slf4j.LoggerFactory
* @param intermediateType the intermediate row type that the state saves
* @param inputType the input row type that the state saves
*/
abstract class UnboundedEventTimeOverProcessFunction(
abstract class RowTimeUnboundedOver(
genAggregations: GeneratedAggregationsFunction,
intermediateType: TypeInformation[Row],
inputType: TypeInformation[Row])
......@@ -214,11 +214,11 @@ abstract class UnboundedEventTimeOverProcessFunction(
* A ProcessFunction to support unbounded ROWS window.
* The ROWS clause defines on a physical level how many rows are included in a window frame.
*/
class UnboundedEventTimeRowsOverProcessFunction(
class RowTimeUnboundedRowsOver(
genAggregations: GeneratedAggregationsFunction,
intermediateType: TypeInformation[Row],
inputType: TypeInformation[Row])
extends UnboundedEventTimeOverProcessFunction(
extends RowTimeUnboundedOver(
genAggregations: GeneratedAggregationsFunction,
intermediateType,
inputType) {
......@@ -252,11 +252,11 @@ class UnboundedEventTimeRowsOverProcessFunction(
* The RANGE option includes all the rows within the window frame
* that have the same ORDER BY values as the current row.
*/
class UnboundedEventTimeRangeOverProcessFunction(
class RowTimeUnboundedRangeOver(
genAggregations: GeneratedAggregationsFunction,
intermediateType: TypeInformation[Row],
inputType: TypeInformation[Row])
extends UnboundedEventTimeOverProcessFunction(
extends RowTimeUnboundedOver(
genAggregations: GeneratedAggregationsFunction,
intermediateType,
inputType) {
......
......@@ -165,7 +165,7 @@ class BoundedProcessingOverRangeProcessFunctionTest {
val genAggFunction = GeneratedAggregationsFunction(funcName, funcCode)
val processFunction = new KeyedProcessOperator[String, Row, Row](
new BoundedProcessingOverRangeProcessFunction(
new ProcTimeBoundedRangeOver(
genAggFunction,
1000,
aggregationStateType,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册