Commit d4665a00 authored by 金竹, committed by Fabian Hueske

[FLINK-5655] [table] Add event-time OVER RANGE BETWEEN x PRECEDING aggregation to SQL.

This closes #3629.
Parent ca681101
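The SQL shape this change enables is exercised by the ITCases further down. A minimal sketch, assuming an event-time table T1 with fields a, b, c registered as in SqlITCase (the source and table registration are omitted here):

import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.types.Row

object RangeOverExample {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
    val tEnv = TableEnvironment.getTableEnvironment(env)
    // T1 (fields 'a, 'b, 'c, event-time source) is assumed to be registered as in the tests below.

    val sqlQuery =
      "SELECT c, b, " +
      "COUNT(a) OVER (PARTITION BY c ORDER BY RowTime() " +
      "RANGE BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW), " +
      "SUM(a) OVER (PARTITION BY c ORDER BY RowTime() " +
      "RANGE BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW) " +
      "FROM T1"
    val result: DataStream[Row] = tEnv.sql(sqlQuery).toDataStream[Row]
    result.print()
    env.execute()
  }
}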
......@@ -139,11 +139,16 @@ class DataStreamOverAggregate(
// bounded OVER window
if (overWindow.isRows) {
// ROWS clause bounded OVER window
createRowsClauseBoundedAndCurrentRowOverWindow(inputDS, isRowTimeType = true)
createBoundedAndCurrentRowOverWindow(
inputDS,
isRangeClause = false,
isRowTimeType = true)
} else {
// RANGE clause bounded OVER window
throw new TableException(
"row-time OVER RANGE PRECEDING window is not supported yet.")
createBoundedAndCurrentRowOverWindow(
inputDS,
isRangeClause = true,
isRowTimeType = true)
}
} else {
throw new TableException(
......@@ -195,8 +200,9 @@ class DataStreamOverAggregate(
result
}
def createRowsClauseBoundedAndCurrentRowOverWindow(
def createBoundedAndCurrentRowOverWindow(
inputDS: DataStream[Row],
isRangeClause: Boolean = false,
isRowTimeType: Boolean = false): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
......@@ -209,10 +215,11 @@ class DataStreamOverAggregate(
// get the output types
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
val processFunction = AggregateUtil.createRowsClauseBoundedOverProcessFunction(
val processFunction = AggregateUtil.createBoundedOverProcessFunction(
namedAggregates,
inputType,
precedingOffset,
isRangeClause,
isRowTimeType
)
val result: DataStream[Row] =
......
......@@ -91,20 +91,21 @@ object AggregateUtil {
}
/**
* Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for ROWS clause
* Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for
* bounded OVER window to evaluate final aggregate value.
*
* @param namedAggregates List of calls to aggregate functions and their output field names
* @param inputType Input row type
* @param precedingOffset the preceding offset
* @param isRangeClause flag indicating whether the OVER clause is a RANGE clause
* @param isRowTimeType flag indicating whether the time attribute is row time
* @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
*/
private[flink] def createRowsClauseBoundedOverProcessFunction(
private[flink] def createBoundedOverProcessFunction(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
precedingOffset: Long,
isRangeClause: Boolean,
isRowTimeType: Boolean): ProcessFunction[Row, Row] = {
val (aggFields, aggregates) =
......@@ -117,14 +118,25 @@ object AggregateUtil {
val inputRowType = FlinkTypeFactory.toInternalRowTypeInfo(inputType).asInstanceOf[RowTypeInfo]
if (isRowTimeType) {
new RowsClauseBoundedOverProcessFunction(
aggregates,
aggFields,
inputType.getFieldCount,
aggregationStateType,
inputRowType,
precedingOffset
)
if (isRangeClause) {
new RangeClauseBoundedOverProcessFunction(
aggregates,
aggFields,
inputType.getFieldCount,
aggregationStateType,
inputRowType,
precedingOffset
)
} else {
new RowsClauseBoundedOverProcessFunction(
aggregates,
aggFields,
inputType.getFieldCount,
aggregationStateType,
inputRowType,
precedingOffset
)
}
} else {
throw TableException(
"Bounded partitioned proc-time OVER aggregation is not supported yet.")
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.aggregate
import java.util.{List => JList, ArrayList => JArrayList}
import org.apache.flink.api.common.state._
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
/**
* Process Function for RANGE clause event-time bounded OVER window
*
* @param aggregates the list of all [[AggregateFunction]] used for this aggregation
* @param aggFields the position (in the input Row) of the input value for each aggregate
* @param forwardedFieldCount the count of forwarded fields.
* @param aggregationStateType the row type info of aggregation
* @param inputRowType the row type info of input row
* @param precedingOffset the preceding offset
*/
class RangeClauseBoundedOverProcessFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val forwardedFieldCount: Int,
private val aggregationStateType: RowTypeInfo,
private val inputRowType: RowTypeInfo,
private val precedingOffset: Long)
extends ProcessFunction[Row, Row] {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(aggFields)
Preconditions.checkArgument(aggregates.length == aggFields.length)
Preconditions.checkNotNull(forwardedFieldCount)
Preconditions.checkNotNull(aggregationStateType)
Preconditions.checkNotNull(precedingOffset)
private var output: Row = _
// the state which keeps the last triggering timestamp
private var lastTriggeringTsState: ValueState[Long] = _
// the state which is used to materialize the accumulator for incremental calculation
private var accumulatorState: ValueState[Row] = _
// the state which keeps all the data that is not expired.
// The map key is the timestamp; the map value is a list that holds all rows
// belonging to that timestamp.
private var dataState: MapState[Long, JList[Row]] = _
override def open(config: Configuration) {
output = new Row(forwardedFieldCount + aggregates.length)
val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] =
new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long])
lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)
val accumulatorStateDescriptor =
new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType)
accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor)
val keyTypeInformation: TypeInformation[Long] =
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType)
val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
new MapStateDescriptor[Long, JList[Row]](
"dataState",
keyTypeInformation,
valueTypeInformation)
dataState = getRuntimeContext.getMapState(mapStateDescriptor)
}
override def processElement(
input: Row,
ctx: ProcessFunction[Row, Row]#Context,
out: Collector[Row]): Unit = {
// triggering timestamp for trigger calculation
val triggeringTs = ctx.timestamp
val lastTriggeringTs = lastTriggeringTsState.value
// check if the data is expired; if not, save the data and register an event-time timer
if (triggeringTs > lastTriggeringTs) {
val data = dataState.get(triggeringTs)
if (null != data) {
data.add(input)
dataState.put(triggeringTs, data)
} else {
val data = new JArrayList[Row]
data.add(input)
dataState.put(triggeringTs, data)
// register event time timer
ctx.timerService.registerEventTimeTimer(triggeringTs)
}
}
}
override def onTimer(
timestamp: Long,
ctx: ProcessFunction[Row, Row]#OnTimerContext,
out: Collector[Row]): Unit = {
// gets all window data from state for the calculation
val inputs: JList[Row] = dataState.get(timestamp)
if (null != inputs) {
var accumulators = accumulatorState.value
var dataListIndex = 0
var aggregatesIndex = 0
// initialize the accumulators on the first run, or after failover recovery, per key
if (null == accumulators) {
accumulators = new Row(aggregates.length)
aggregatesIndex = 0
while (aggregatesIndex < aggregates.length) {
accumulators.setField(aggregatesIndex, aggregates(aggregatesIndex).createAccumulator())
aggregatesIndex += 1
}
}
// collect the timestamps of the data to be retracted
val retractTsList: JList[Long] = new JArrayList[Long]
// do retraction
val dataTimestampIt = dataState.keys.iterator
while (dataTimestampIt.hasNext) {
val dataTs: Long = dataTimestampIt.next()
val offset = timestamp - dataTs
if (offset > precedingOffset) {
val retractDataList = dataState.get(dataTs)
dataListIndex = 0
while (dataListIndex < retractDataList.size()) {
aggregatesIndex = 0
while (aggregatesIndex < aggregates.length) {
val accumulator = accumulators.getField(aggregatesIndex).asInstanceOf[Accumulator]
aggregates(aggregatesIndex)
.retract(accumulator, retractDataList.get(dataListIndex)
.getField(aggFields(aggregatesIndex)))
aggregatesIndex += 1
}
dataListIndex += 1
}
retractTsList.add(dataTs)
}
}
// do accumulation
dataListIndex = 0
while (dataListIndex < inputs.size()) {
// accumulate current row
aggregatesIndex = 0
while (aggregatesIndex < aggregates.length) {
val accumulator = accumulators.getField(aggregatesIndex).asInstanceOf[Accumulator]
aggregates(aggregatesIndex).accumulate(accumulator, inputs.get(dataListIndex)
.getField(aggFields(aggregatesIndex)))
aggregatesIndex += 1
}
dataListIndex += 1
}
// set aggregate in output row
aggregatesIndex = 0
while (aggregatesIndex < aggregates.length) {
val index = forwardedFieldCount + aggregatesIndex
val accumulator = accumulators.getField(aggregatesIndex).asInstanceOf[Accumulator]
output.setField(index, aggregates(aggregatesIndex).getValue(accumulator))
aggregatesIndex += 1
}
// copy forwarded fields to output row and emit output row
dataListIndex = 0
while (dataListIndex < inputs.size()) {
aggregatesIndex = 0
while (aggregatesIndex < forwardedFieldCount) {
output.setField(aggregatesIndex, inputs.get(dataListIndex).getField(aggregatesIndex))
aggregatesIndex += 1
}
out.collect(output)
dataListIndex += 1
}
// remove the data that has been retracted
dataListIndex = 0
while (dataListIndex < retractTsList.size) {
dataState.remove(retractTsList.get(dataListIndex))
dataListIndex += 1
}
// update state
accumulatorState.update(accumulators)
lastTriggeringTsState.update(timestamp)
}
}
}
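The retraction loop in onTimer above removes a row once timestamp - dataTs exceeds precedingOffset, so the RANGE window is the inclusive interval [timestamp - precedingOffset, timestamp]. A small standalone sketch of just that boundary check (illustrative only, not part of the commit):

object RangeBoundarySketch {
  // mirrors the `offset > precedingOffset` check used for retraction in onTimer
  def isRetracted(currentTs: Long, dataTs: Long, precedingOffset: Long): Boolean =
    currentTs - dataTs > precedingOffset

  def main(args: Array[String]): Unit = {
    val precedingOffset = 1000L // INTERVAL '1' SECOND
    println(isRetracted(2000L, 999L, precedingOffset))  // true: falls out of the window
    println(isRetracted(2000L, 1000L, precedingOffset)) // false: the lower boundary is inclusive
    println(isRetracted(2000L, 2000L, precedingOffset)) // false: the current timestamp itself
  }
}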
......@@ -411,6 +411,150 @@ class SqlITCase extends StreamingWithStateTestBase {
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
@Test
def testBoundPartitionedEventTimeWindowWithRange(): Unit = {
val data = Seq(
Left((1500L, (1L, 15, "Hello"))),
Left((1600L, (1L, 16, "Hello"))),
Left((1000L, (1L, 1, "Hello"))),
Left((2000L, (2L, 2, "Hello"))),
Right(1000L),
Left((2000L, (2L, 2, "Hello"))),
Left((2000L, (2L, 3, "Hello"))),
Left((3000L, (3L, 3, "Hello"))),
Right(2000L),
Left((4000L, (4L, 4, "Hello"))),
Right(3000L),
Left((5000L, (5L, 5, "Hello"))),
Right(5000L),
Left((6000L, (6L, 6, "Hello"))),
Left((6500L, (6L, 65, "Hello"))),
Right(7000L),
Left((9000L, (6L, 9, "Hello"))),
Left((9500L, (6L, 18, "Hello"))),
Left((9000L, (6L, 9, "Hello"))),
Right(10000L),
Left((10000L, (7L, 7, "Hello World"))),
Left((11000L, (7L, 17, "Hello World"))),
Left((11000L, (7L, 77, "Hello World"))),
Right(12000L),
Left((14000L, (7L, 18, "Hello World"))),
Right(14000L),
Left((15000L, (8L, 8, "Hello World"))),
Right(17000L),
Left((20000L, (20L, 20, "Hello World"))),
Right(19000L))
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
env.setStateBackend(getStateBackend)
val tEnv = TableEnvironment.getTableEnvironment(env)
StreamITCase.clear
val t1 = env
.addSource[(Long, Int, String)](new EventTimeSourceFunction[(Long, Int, String)](data))
.toTable(tEnv).as('a, 'b, 'c)
tEnv.registerTable("T1", t1)
val sqlQuery = "SELECT " +
"c, b, " +
"count(a) OVER (PARTITION BY c ORDER BY RowTime() RANGE BETWEEN INTERVAL '1' SECOND " +
"preceding AND CURRENT ROW)" +
", sum(a) OVER (PARTITION BY c ORDER BY RowTime() RANGE BETWEEN INTERVAL '1' SECOND " +
" preceding AND CURRENT ROW)" +
" from T1"
val result = tEnv.sql(sqlQuery).toDataStream[Row]
result.addSink(new StreamITCase.StringSink)
env.execute()
val expected = mutable.MutableList(
"Hello,1,1,1", "Hello,15,2,2", "Hello,16,3,3",
"Hello,2,6,9", "Hello,3,6,9","Hello,2,6,9",
"Hello,3,4,9",
"Hello,4,2,7",
"Hello,5,2,9",
"Hello,6,2,11","Hello,65,2,12",
"Hello,9,2,12","Hello,9,2,12","Hello,18,3,18",
"Hello World,7,1,7", "Hello World,17,3,21", "Hello World,77,3,21", "Hello World,18,1,7",
"Hello World,8,2,15",
"Hello World,20,1,20")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
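// Worked check of the expected values above (illustrative sketch, not part of the commit):
// when the watermark reaches 2000, the RANGE window [1000, 2000] for key "Hello" holds the
// rows with a = 1 (timestamps 1000, 1500, 1600) and a = 2 (the three rows at timestamp 2000),
// so COUNT(a) = 6 and SUM(a) = 9, which yields the rows "Hello,2,6,9" and "Hello,3,6,9".
@Test
def testRangeWindowArithmeticSketch(): Unit = {
  val aValuesInWindow = Seq(1L, 1L, 1L, 2L, 2L, 2L)
  assertEquals(6, aValuesInWindow.size) // COUNT(a)
  assertEquals(9L, aValuesInWindow.sum) // SUM(a)
}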
@Test
def testBoundNonPartitionedEventTimeWindowWithRange(): Unit = {
val data = Seq(
Left((1500L, (1L, 15, "Hello"))),
Left((1600L, (1L, 16, "Hello"))),
Left((1000L, (1L, 1, "Hello"))),
Left((2000L, (2L, 2, "Hello"))),
Right(1000L),
Left((2000L, (2L, 2, "Hello"))),
Left((2000L, (2L, 3, "Hello"))),
Left((3000L, (3L, 3, "Hello"))),
Right(2000L),
Left((4000L, (4L, 4, "Hello"))),
Right(3000L),
Left((5000L, (5L, 5, "Hello"))),
Right(5000L),
Left((6000L, (6L, 6, "Hello"))),
Left((6500L, (6L, 65, "Hello"))),
Right(7000L),
Left((9000L, (6L, 9, "Hello"))),
Left((9500L, (6L, 18, "Hello"))),
Left((9000L, (6L, 9, "Hello"))),
Right(10000L),
Left((10000L, (7L, 7, "Hello World"))),
Left((11000L, (7L, 17, "Hello World"))),
Left((11000L, (7L, 77, "Hello World"))),
Right(12000L),
Left((14000L, (7L, 18, "Hello World"))),
Right(14000L),
Left((15000L, (8L, 8, "Hello World"))),
Right(17000L),
Left((20000L, (20L, 20, "Hello World"))),
Right(19000L))
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
env.setStateBackend(getStateBackend)
val tEnv = TableEnvironment.getTableEnvironment(env)
StreamITCase.clear
val t1 = env
.addSource[(Long, Int, String)](new EventTimeSourceFunction[(Long, Int, String)](data))
.toTable(tEnv).as('a, 'b, 'c)
tEnv.registerTable("T1", t1)
val sqlQuery = "SELECT " +
"c, b, " +
"count(a) OVER (ORDER BY RowTime() RANGE BETWEEN INTERVAL '1' SECOND " +
"preceding AND CURRENT ROW)" +
", sum(a) OVER (ORDER BY RowTime() RANGE BETWEEN INTERVAL '1' SECOND " +
" preceding AND CURRENT ROW)" +
" from T1"
val result = tEnv.sql(sqlQuery).toDataStream[Row]
result.addSink(new StreamITCase.StringSink)
env.execute()
val expected = mutable.MutableList(
"Hello,1,1,1", "Hello,15,2,2", "Hello,16,3,3",
"Hello,2,6,9", "Hello,3,6,9","Hello,2,6,9",
"Hello,3,4,9",
"Hello,4,2,7",
"Hello,5,2,9",
"Hello,6,2,11","Hello,65,2,12",
"Hello,9,2,12","Hello,9,2,12","Hello,18,3,18",
"Hello World,7,4,25", "Hello World,17,3,21", "Hello World,77,3,21", "Hello World,18,1,7",
"Hello World,8,2,15",
"Hello World,20,1,20")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
/**
* All aggregates must be computed on the same window.
*/
......
......@@ -350,4 +350,59 @@ class WindowAggregateTest extends TableTestBase {
streamUtil.verifySql(sql, expected)
}
@Test
def testBoundPartitionedRowTimeWindowWithRange() = {
val sql = "SELECT " +
"c, " +
"count(a) OVER (PARTITION BY c ORDER BY RowTime() " +
"RANGE BETWEEN INTERVAL '1' SECOND preceding AND CURRENT ROW) as cnt1 " +
"from MyTable"
val expected =
unaryNode(
"DataStreamCalc",
unaryNode(
"DataStreamOverAggregate",
unaryNode(
"DataStreamCalc",
streamTableNode(0),
term("select", "a", "c", "ROWTIME() AS $2")
),
term("partitionBy", "c"),
term("orderBy", "ROWTIME"),
term("range", "BETWEEN 1000 PRECEDING AND CURRENT ROW"),
term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0")
),
term("select", "c", "w0$o0 AS $1")
)
streamUtil.verifySql(sql, expected)
}
@Test
def testBoundNonPartitionedRowTimeWindowWithRange() = {
val sql = "SELECT " +
"c, " +
"count(a) OVER (ORDER BY RowTime() " +
"RANGE BETWEEN INTERVAL '1' SECOND preceding AND CURRENT ROW) as cnt1 " +
"from MyTable"
val expected =
unaryNode(
"DataStreamCalc",
unaryNode(
"DataStreamOverAggregate",
unaryNode(
"DataStreamCalc",
streamTableNode(0),
term("select", "a", "c", "ROWTIME() AS $2")
),
term("orderBy", "ROWTIME"),
term("range", "BETWEEN 1000 PRECEDING AND CURRENT ROW"),
term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0")
),
term("select", "c", "w0$o0 AS $1")
)
streamUtil.verifySql(sql, expected)
}
}