Commit 7a9d39fe authored by 金竹, committed by Fabian Hueske

[FLINK-5990] [table] Add event-time OVER ROWS BETWEEN x PRECEDING aggregation to SQL.

This closes #3585.
Parent: 6949c8c7
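
For context, this change enables bounded event-time ROWS OVER aggregations of the following shape (adapted from the tests added below; T1, a and c are simply the test table and columns):

    SELECT c, a,
      COUNT(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 PRECEDING AND CURRENT ROW),
      SUM(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 PRECEDING AND CURRENT ROW)
    FROM T1
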
......@@ -18,12 +18,15 @@
package org.apache.flink.table.plan.nodes
import org.apache.calcite.rel.RelFieldCollation
import org.apache.calcite.rel.{RelFieldCollation, RelNode}
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFieldImpl}
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.core.Window.Group
import org.apache.calcite.rel.core.Window
import org.apache.calcite.rex.RexInputRef
import org.apache.flink.table.runtime.aggregate.AggregateUtil._
import org.apache.flink.table.functions.{ProcTimeType, RowTimeType}
import scala.collection.JavaConverters._
trait OverAggregate {
......@@ -46,8 +49,16 @@ trait OverAggregate {
orderingString
}
private[flink] def windowRange(overWindow: Group): String = {
s"BETWEEN ${overWindow.lowerBound} AND ${overWindow.upperBound}"
private[flink] def windowRange(
logicWindow: Window,
overWindow: Group,
input: RelNode): String = {
if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded) {
s"BETWEEN ${getLowerBoundary(logicWindow, overWindow, input)} PRECEDING " +
s"AND ${overWindow.upperBound}"
} else {
s"BETWEEN ${overWindow.lowerBound} AND ${overWindow.upperBound}"
}
}
private[flink] def aggregationToString(
......@@ -92,4 +103,18 @@ trait OverAggregate {
}.mkString(", ")
}
private[flink] def getLowerBoundary(
logicWindow: Window,
overWindow: Group,
input: RelNode): Long = {
val ref: RexInputRef = overWindow.lowerBound.getOffset.asInstanceOf[RexInputRef]
val lowerBoundIndex = input.getRowType.getFieldCount - ref.getIndex
val lowerBound = logicWindow.constants.get(lowerBoundIndex).getValue2
lowerBound match {
case x: java.math.BigDecimal => x.longValue()
case _ => lowerBound.asInstanceOf[Long]
}
}
}
......@@ -32,6 +32,7 @@ import org.apache.calcite.rel.core.Window
import org.apache.calcite.rel.core.Window.Group
import java.util.{List => JList}
import org.apache.flink.api.java.functions.NullByteKeySelector
import org.apache.flink.table.functions.{ProcTimeType, RowTimeType}
import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
......@@ -70,9 +71,9 @@ class DataStreamOverAggregate(
super.explainTerms(pw)
.itemIf("partitionBy", partitionToString(inputType, partitionKeys), partitionKeys.nonEmpty)
.item("orderBy",orderingToString(inputType, overWindow.orderKeys.getFieldCollations))
.itemIf("rows", windowRange(overWindow), overWindow.isRows)
.itemIf("range", windowRange(overWindow), !overWindow.isRows)
.item("orderBy",orderingToString(inputType, overWindow.orderKeys.getFieldCollations))
.itemIf("rows", windowRange(logicWindow, overWindow, getInput), overWindow.isRows)
.itemIf("range", windowRange(logicWindow, overWindow, getInput), !overWindow.isRows)
.item(
"select", aggregationToString(
inputType,
......@@ -99,20 +100,58 @@ class DataStreamOverAggregate(
.getFieldList
.get(overWindow.orderKeys.getFieldCollations.get(0).getFieldIndex)
.getValue
timeType match {
case _: ProcTimeType =>
// both ROWS and RANGE clause with UNBOUNDED PRECEDING and CURRENT ROW condition.
if (overWindow.lowerBound.isUnbounded &&
overWindow.upperBound.isCurrentRow) {
// proc-time OVER window
if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
// non-bounded OVER window
createUnboundedAndCurrentRowProcessingTimeOverWindow(inputDS)
} else if (
overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded &&
overWindow.upperBound.isCurrentRow) {
// bounded OVER window
if (overWindow.isRows) {
// ROWS clause bounded OVER window
throw new TableException(
"ROWS clause bounded proc-time OVER window is not supported yet.")
} else {
// RANGE clause bounded OVER window
throw new TableException(
"RANGE clause bounded proc-time OVER window is not supported yet.")
}
} else {
throw new TableException(
"OVER window only supports ProcessingTime UNBOUNDED PRECEDING and CURRENT ROW " +
"condition.")
"OVER window only supports ProcessingTime UNBOUNDED PRECEDING and CURRENT ROW " +
"condition.")
}
case _: RowTimeType =>
throw new TableException("OVER Window of the EventTime type is not currently supported.")
// row-time OVER window
if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
// non-bounded OVER window
if (overWindow.isRows) {
// ROWS clause unbounded OVER window
throw new TableException(
"ROWS clause unbounded row-time OVER window is not supported yet.")
} else {
// RANGE clause unbounded OVER window
throw new TableException(
"RANGE clause unbounded row-time OVER window is not supported yet.")
}
} else if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded &&
overWindow.upperBound.isCurrentRow) {
// bounded OVER window
if (overWindow.isRows) {
// ROWS clause bounded OVER window
createRowsClauseBoundedAndCurrentRowOverWindow(inputDS, true)
} else {
// RANGE clause bounded OVER window
throw new TableException(
"RANGE clause bounded row-time OVER window is not supported yet.")
}
} else {
throw new TableException(
"row-time OVER window only supports the CURRENT ROW condition.")
}
case _ =>
throw new TableException(s"Unsupported time type $timeType")
}
......@@ -120,7 +159,7 @@ class DataStreamOverAggregate(
}
def createUnboundedAndCurrentRowProcessingTimeOverWindow(
inputDS: DataStream[Row]): DataStream[Row] = {
inputDS: DataStream[Row]): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
......@@ -130,32 +169,78 @@ class DataStreamOverAggregate(
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
val result: DataStream[Row] =
// partitioned aggregation
if (partitionKeys.nonEmpty) {
val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction(
namedAggregates,
inputType)
// partitioned aggregation
if (partitionKeys.nonEmpty) {
val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction(
namedAggregates,
inputType)
inputDS
inputDS
.keyBy(partitionKeys: _*)
.process(processFunction)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
// non-partitioned aggregation
else {
val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction(
namedAggregates,
inputType,
false)
inputDS
.process(processFunction).setParallelism(1).setMaxParallelism(1)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
}
// non-partitioned aggregation
else {
val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction(
namedAggregates,
inputType,
false)
inputDS
.process(processFunction).setParallelism(1).setMaxParallelism(1)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
result
}
def createRowsClauseBoundedAndCurrentRowOverWindow(
inputDS: DataStream[Row],
isRowTimeType: Boolean = false): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
val inputFields = (0 until inputType.getFieldCount).toArray
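// The "+ 1" below converts the PRECEDING offset into a row count: e.g. ROWS BETWEEN
// 2 PRECEDING AND CURRENT ROW keeps at most 3 rows (2 preceding plus the current row).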
val precedingOffset =
getLowerBoundary(logicWindow, overWindow, getInput()) + 1
// get the output types
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
val processFunction = AggregateUtil.createRowsClauseBoundedOverProcessFunction(
namedAggregates,
inputType,
inputFields,
precedingOffset,
isRowTimeType
)
val result: DataStream[Row] =
// partitioned aggregation
if (partitionKeys.nonEmpty) {
inputDS
.keyBy(partitionKeys: _*)
.process(processFunction)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
// non-partitioned aggregation
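// (Keying on the constant NullByteKeySelector funnels all rows to a single key, so the
// keyed state and event-time timers of the process function remain usable; parallelism 1
// keeps a single, ordered stream.)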
else {
inputDS
.keyBy(new NullByteKeySelector[Row])
.process(processFunction)
.setParallelism(1)
.setMaxParallelism(1)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
result
}
......@@ -180,7 +265,7 @@ class DataStreamOverAggregate(
}
}ORDER BY: ${orderingToString(inputType, overWindow.orderKeys.getFieldCollations)}, " +
s"${if (overWindow.isRows) "ROWS" else "RANGE"}" +
s"${windowRange(overWindow)}, " +
s"${windowRange(logicWindow, overWindow, getInput)}, " +
s"select: (${
aggregationToString(
inputType,
......
......@@ -61,7 +61,7 @@ object AggregateUtil {
* @param isPartitioned Flag to indicate whether the input is partitioned or not
* @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
*/
private[flink] def CreateUnboundedProcessingOverProcessFunction(
private[flink] def createUnboundedProcessingOverProcessFunction(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
isPartitioned: Boolean = true): ProcessFunction[Row, Row] = {
......@@ -90,6 +90,52 @@ object AggregateUtil {
}
}
/**
* Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for a ROWS clause
* bounded OVER window to evaluate the final aggregate value.
*
* @param namedAggregates List of calls to aggregate functions and their output field names
* @param inputType Input row type
* @param inputFields All input fields
* @param precedingOffset the preceding offset
* @param isRowTimeType Flag that indicates whether the time type is row time
* @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
*/
private[flink] def createRowsClauseBoundedOverProcessFunction(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
inputFields: Array[Int],
precedingOffset: Long,
isRowTimeType: Boolean): ProcessFunction[Row, Row] = {
val (aggFields, aggregates) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
needRetraction = true)
val aggregationStateType: RowTypeInfo =
createDataSetAggregateBufferDataType(Array(), aggregates, inputType)
val inputRowType: RowTypeInfo =
createDataSetAggregateBufferDataType(inputFields, Array(), inputType)
val processFunction = if (isRowTimeType) {
new RowsClauseBoundedOverProcessFunction(
aggregates,
aggFields,
inputType.getFieldCount,
aggregationStateType,
inputRowType,
precedingOffset
)
} else {
throw TableException(
"Bounded partitioned proc-time OVER aggregation is not supported yet.")
}
processFunction
}
/**
* Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates.
* The output of the function contains the grouping keys and the timestamp and the intermediate
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.aggregate
import java.util
import java.util.{List => JList}
import org.apache.flink.api.common.state._
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
/**
* ProcessFunction for a ROWS clause event-time bounded OVER window.
*
* @param aggregates the list of all [[AggregateFunction]] used for this aggregation
* @param aggFields the position (in the input Row) of the input value for each aggregate
* @param forwardedFieldCount the count of forwarded fields.
* @param aggregationStateType the row type info of the accumulator (aggregation state) row
* @param inputRowType the row type info of input row
* @param precedingOffset the preceding offset
*/
class RowsClauseBoundedOverProcessFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val forwardedFieldCount: Int,
private val aggregationStateType: RowTypeInfo,
private val inputRowType: RowTypeInfo,
private val precedingOffset: Long)
extends ProcessFunction[Row, Row] {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(aggFields)
Preconditions.checkArgument(aggregates.length == aggFields.length)
Preconditions.checkNotNull(forwardedFieldCount)
Preconditions.checkNotNull(aggregationStateType)
Preconditions.checkNotNull(precedingOffset)
private var output: Row = _
// the state which keeps the last triggering timestamp
private var lastTriggeringTsState: ValueState[Long] = _
// the state which keeps the count of data
private var dataCountState: ValueState[Long] = _
// the state which is used to materialize the accumulators for incremental calculation
private var accumulatorState: ValueState[Row] = _
// the state which keeps all the data that are not expired.
// The map key is the timestamp of a row; the value is a list holding all rows received
// with that timestamp.
private var dataState: MapState[Long, JList[Row]] = _
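// For illustration, after rows with timestamps 1, 1 and 2 have been buffered:
// dataState = { 1 -> [row, row], 2 -> [row] }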
override def open(config: Configuration) {
output = new Row(forwardedFieldCount + aggregates.length)
val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] =
new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long])
lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)
val dataCountStateDescriptor =
new ValueStateDescriptor[Long]("dataCountState", classOf[Long])
dataCountState = getRuntimeContext.getState(dataCountStateDescriptor)
val accumulatorStateDescriptor =
new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType)
accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor)
val keyTypeInformation: TypeInformation[Long] =
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType)
val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
new MapStateDescriptor[Long, JList[Row]](
"dataState",
keyTypeInformation,
valueTypeInformation)
dataState = getRuntimeContext.getMapState(mapStateDescriptor)
}
override def processElement(
input: Row,
ctx: ProcessFunction[Row, Row]#Context,
out: Collector[Row]): Unit = {
// timestamp of the current row, used to buffer it and to register the event-time timer
val triggeringTs = ctx.timestamp
val lastTriggeringTs = lastTriggeringTsState.value
// if the row is not late (its timestamp is greater than the last triggering timestamp),
// buffer it in the data state and register an event-time timer; late rows are dropped
if (triggeringTs > lastTriggeringTs) {
val data = dataState.get(triggeringTs)
if (null != data) {
data.add(input)
dataState.put(triggeringTs, data)
} else {
val data = new util.ArrayList[Row]
data.add(input)
dataState.put(triggeringTs, data)
// register event time timer
ctx.timerService.registerEventTimeTimer(triggeringTs)
}
}
}
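// Note: results are emitted only from onTimer(), i.e. once the watermark has passed the
// timestamp of the buffered rows.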
override def onTimer(
timestamp: Long,
ctx: ProcessFunction[Row, Row]#OnTimerContext,
out: Collector[Row]): Unit = {
// gets all window data from state for the calculation
val inputs: JList[Row] = dataState.get(timestamp)
if (null != inputs) {
var accumulators = accumulatorState.value
var dataCount = dataCountState.value
var retractList: JList[Row] = null
var retractTs: Long = Long.MaxValue
var retractCnt: Int = 0
var j = 0
var i = 0
while (j < inputs.size) {
val input = inputs.get(j)
// initialize the accumulators for this key on the first run or after failover recovery
if (null == accumulators) {
accumulators = new Row(aggregates.length)
i = 0
while (i < aggregates.length) {
accumulators.setField(i, aggregates(i).createAccumulator())
i += 1
}
}
var retractRow: Row = null
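// Once `precedingOffset` rows have been accumulated for this key, every further row
// retracts the oldest buffered row, so the window never exceeds its ROWS bound.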
if (dataCount >= precedingOffset) {
if (null == retractList) {
// find the smallest timestamp
retractTs = Long.MaxValue
val dataTimestampIt = dataState.keys.iterator
while (dataTimestampIt.hasNext) {
val dataTs = dataTimestampIt.next
if (dataTs < retractTs) {
retractTs = dataTs
}
}
// get the oldest rows to retract them
retractList = dataState.get(retractTs)
}
retractRow = retractList.get(retractCnt)
retractCnt += 1
// remove retracted values from state
if (retractList.size == retractCnt) {
dataState.remove(retractTs)
retractList = null
retractCnt = 0
}
} else {
dataCount += 1
}
// copy forwarded fields to output row
i = 0
while (i < forwardedFieldCount) {
output.setField(i, input.getField(i))
i += 1
}
// retract old row from accumulators
if (null != retractRow) {
i = 0
while (i < aggregates.length) {
val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
aggregates(i).retract(accumulator, retractRow.getField(aggFields(i)))
i += 1
}
}
// accumulate current row and set aggregate in output row
i = 0
while (i < aggregates.length) {
val index = forwardedFieldCount + i
val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
aggregates(i).accumulate(accumulator, input.getField(aggFields(i)))
output.setField(index, aggregates(i).getValue(accumulator))
i += 1
}
j += 1
out.collect(output)
}
// update all states
if (dataState.contains(retractTs)) {
if (retractCnt > 0) {
retractList.subList(0, retractCnt).clear()
dataState.put(retractTs, retractList)
}
}
dataCountState.update(dataCount)
accumulatorState.update(accumulators)
}
lastTriggeringTsState.update(timestamp)
}
}
......@@ -19,14 +19,18 @@
package org.apache.flink.table.api.scala.stream.sql
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.functions.source.SourceFunction
import org.apache.flink.table.api.scala.stream.sql.SqlITCase.EventTimeSourceFunction
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.table.api.{TableEnvironment, TableException}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.scala.stream.utils.{StreamingWithStateTestBase, StreamITCase,
StreamTestData}
import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
import org.apache.flink.types.Row
import org.junit.Assert._
import org.junit._
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext
import scala.collection.mutable
......@@ -293,6 +297,120 @@ class SqlITCase extends StreamingWithStateTestBase {
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
@Test
def testBoundPartitionedEventTimeWindowWithRow(): Unit = {
val data = Seq(
Left((1L, (1L, 1, "Hello"))),
Left((2L, (2L, 2, "Hello"))),
Left((1L, (1L, 1, "Hello"))),
Left((2L, (2L, 2, "Hello"))),
Left((2L, (2L, 2, "Hello"))),
Left((1L, (1L, 1, "Hello"))),
Left((3L, (7L, 7, "Hello World"))),
Left((1L, (7L, 7, "Hello World"))),
Left((1L, (7L, 7, "Hello World"))),
Right(2L),
Left((3L, (3L, 3, "Hello"))),
Left((4L, (4L, 4, "Hello"))),
Left((5L, (5L, 5, "Hello"))),
Left((6L, (6L, 6, "Hello"))),
Left((20L, (20L, 20, "Hello World"))),
Right(6L),
Left((8L, (8L, 8, "Hello World"))),
Left((7L, (7L, 7, "Hello World"))),
Right(20L))
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
env.setStateBackend(getStateBackend)
val tEnv = TableEnvironment.getTableEnvironment(env)
StreamITCase.clear
val t1 = env
.addSource[(Long, Int, String)](new EventTimeSourceFunction[(Long, Int, String)](data))
.toTable(tEnv).as('a, 'b, 'c)
tEnv.registerTable("T1", t1)
val sqlQuery = "SELECT " +
"c, a, " +
"count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)" +
", sum(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)" +
" from T1"
val result = tEnv.sql(sqlQuery).toDataStream[Row]
result.addSink(new StreamITCase.StringSink)
env.execute()
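// Expected rows have the form "c, a, count, sum", where count/sum cover the current row
// and at most the 2 preceding rows of the same partition in event-time order; e.g.
// "Hello,2,3,4" aggregates the rows with a = 1, 1, 2 of partition "Hello".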
val expected = mutable.MutableList(
"Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3",
"Hello,2,3,4", "Hello,2,3,5","Hello,2,3,6",
"Hello,3,3,7", "Hello,4,3,9", "Hello,5,3,12",
"Hello,6,3,15",
"Hello World,7,1,7", "Hello World,7,2,14", "Hello World,7,3,21",
"Hello World,7,3,21", "Hello World,8,3,22", "Hello World,20,3,35")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
@Test
def testBoundNonPartitionedEventTimeWindowWithRow(): Unit = {
val data = Seq(
Left((2L, (2L, 2, "Hello"))),
Left((2L, (2L, 2, "Hello"))),
Left((1L, (1L, 1, "Hello"))),
Left((1L, (1L, 1, "Hello"))),
Left((2L, (2L, 2, "Hello"))),
Left((1L, (1L, 1, "Hello"))),
Left((20L, (20L, 20, "Hello World"))), // early row
Right(3L),
Left((2L, (2L, 2, "Hello"))), // late row
Left((3L, (3L, 3, "Hello"))),
Left((4L, (4L, 4, "Hello"))),
Left((5L, (5L, 5, "Hello"))),
Left((6L, (6L, 6, "Hello"))),
Left((7L, (7L, 7, "Hello World"))),
Right(7L),
Left((9L, (9L, 9, "Hello World"))),
Left((8L, (8L, 8, "Hello World"))),
Left((8L, (8L, 8, "Hello World"))),
Right(20L))
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
env.setStateBackend(getStateBackend)
env.setParallelism(1)
val tEnv = TableEnvironment.getTableEnvironment(env)
StreamITCase.clear
val t1 = env
.addSource[(Long, Int, String)](new EventTimeSourceFunction[(Long, Int, String)](data))
.toTable(tEnv).as('a, 'b, 'c)
tEnv.registerTable("T1", t1)
val sqlQuery = "SELECT " +
"c, a, " +
"count(a) OVER (ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)," +
"sum(a) OVER (ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW) " +
"from T1"
val result = tEnv.sql(sqlQuery).toDataStream[Row]
result.addSink(new StreamITCase.StringSink)
env.execute()
val expected = mutable.MutableList(
"Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3",
"Hello,2,3,4", "Hello,2,3,5", "Hello,2,3,6",
"Hello,3,3,7",
"Hello,4,3,9", "Hello,5,3,12",
"Hello,6,3,15", "Hello World,7,3,18",
"Hello World,8,3,21", "Hello World,8,3,23",
"Hello World,9,3,25",
"Hello World,20,3,37")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
/**
* All aggregates must be computed on the same window.
*/
......@@ -317,4 +435,21 @@ class SqlITCase extends StreamingWithStateTestBase {
result.addSink(new StreamITCase.StringSink)
env.execute()
}
}
object SqlITCase {
class EventTimeSourceFunction[T](
dataWithTimestampList: Seq[Either[(Long, T), Long]]) extends SourceFunction[T] {
override def run(ctx: SourceContext[T]): Unit = {
dataWithTimestampList.foreach {
case Left(t) => ctx.collectWithTimestamp(t._2, t._1)
case Right(w) => ctx.emitWatermark(new Watermark(w))
}
}
override def cancel(): Unit = ???
}
}
......@@ -239,4 +239,59 @@ class WindowAggregateTest extends TableTestBase {
)
streamUtil.verifySql(sql, expected)
}
@Test
def testBoundPartitionedRowTimeWindowWithRow() = {
val sql = "SELECT " +
"c, " +
"count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 5 preceding AND " +
"CURRENT ROW) as cnt1 " +
"from MyTable"
val expected =
unaryNode(
"DataStreamCalc",
unaryNode(
"DataStreamOverAggregate",
unaryNode(
"DataStreamCalc",
streamTableNode(0),
term("select", "a", "c", "ROWTIME() AS $2")
),
term("partitionBy", "c"),
term("orderBy", "ROWTIME"),
term("rows", "BETWEEN 5 PRECEDING AND CURRENT ROW"),
term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0")
),
term("select", "c", "w0$o0 AS $1")
)
streamUtil.verifySql(sql, expected)
}
@Test
def testBoundNonPartitionedRowTimeWindowWithRow() = {
val sql = "SELECT " +
"c, " +
"count(a) OVER (ORDER BY RowTime() ROWS BETWEEN 5 preceding AND " +
"CURRENT ROW) as cnt1 " +
"from MyTable"
val expected =
unaryNode(
"DataStreamCalc",
unaryNode(
"DataStreamOverAggregate",
unaryNode(
"DataStreamCalc",
streamTableNode(0),
term("select", "a", "c", "ROWTIME() AS $2")
),
term("orderBy", "ROWTIME"),
term("rows", "BETWEEN 5 PRECEDING AND CURRENT ROW"),
term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0")
),
term("select", "c", "w0$o0 AS $1")
)
streamUtil.verifySql(sql, expected)
}
}