提交 44f9c76a 编写于 作者: H hongyuhong 00223286 提交者: Fabian Hueske

[FLINK-6200] [table] Add support for unbounded event-time OVER RANGE window.

This closes #3649.
上级 aa3c395b
......@@ -127,14 +127,8 @@ class DataStreamOverAggregate(
// row-time OVER window
if (overWindow.lowerBound.isPreceding &&
overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
if (overWindow.isRows) {
// unbounded preceding OVER ROWS window
createUnboundedAndCurrentRowEventTimeOverWindow(inputDS)
} else {
// unbounded preceding OVER RANGE window
throw new TableException(
"row-time OVER RANGE UNBOUNDED PRECEDING window is not supported yet.")
}
// ROWS/RANGE clause unbounded OVER window
createUnboundedAndCurrentRowEventTimeOverWindow(inputDS, overWindow.isRows)
} else if (overWindow.lowerBound.isPreceding && overWindow.upperBound.isCurrentRow) {
// bounded OVER window
if (overWindow.isRows) {
......@@ -202,8 +196,8 @@ class DataStreamOverAggregate(
def createBoundedAndCurrentRowOverWindow(
inputDS: DataStream[Row],
isRangeClause: Boolean = false,
isRowTimeType: Boolean = false): DataStream[Row] = {
isRangeClause: Boolean,
isRowTimeType: Boolean): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
......@@ -247,7 +241,8 @@ class DataStreamOverAggregate(
}
def createUnboundedAndCurrentRowEventTimeOverWindow(
inputDS: DataStream[Row]): DataStream[Row] = {
inputDS: DataStream[Row],
isRows: Boolean): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
......@@ -258,7 +253,8 @@ class DataStreamOverAggregate(
val processFunction = AggregateUtil.createUnboundedEventTimeOverProcessFunction(
namedAggregates,
inputType)
inputType,
isRows)
val result: DataStream[Row] =
// partitioned aggregation
......
......@@ -152,7 +152,8 @@ object AggregateUtil {
*/
private[flink] def createUnboundedEventTimeOverProcessFunction(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType): UnboundedEventTimeOverProcessFunction = {
inputType: RelDataType,
isRows: Boolean): UnboundedEventTimeOverProcessFunction = {
val (aggFields, aggregates) =
transformToAggregateFunctions(
......@@ -162,12 +163,23 @@ object AggregateUtil {
val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates)
new UnboundedEventTimeOverProcessFunction(
aggregates,
aggFields,
inputType.getFieldCount,
aggregationStateType,
FlinkTypeFactory.toInternalRowTypeInfo(inputType))
if (isRows) {
// ROWS unbounded over process function
new UnboundedEventTimeRowsOverProcessFunction(
aggregates,
aggFields,
inputType.getFieldCount,
aggregationStateType,
FlinkTypeFactory.toInternalRowTypeInfo(inputType))
} else {
// RANGE unbounded over process function
new UnboundedEventTimeRangeOverProcessFunction(
aggregates,
aggFields,
inputType.getFieldCount,
aggregationStateType,
FlinkTypeFactory.toInternalRowTypeInfo(inputType))
}
}
/**
......
......@@ -41,7 +41,7 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
* @param inputType the input row tye which the state saved
*
*/
class UnboundedEventTimeOverProcessFunction(
abstract class UnboundedEventTimeOverProcessFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val forwardedFieldCount: Int,
......@@ -53,7 +53,7 @@ class UnboundedEventTimeOverProcessFunction(
Preconditions.checkNotNull(aggFields)
Preconditions.checkArgument(aggregates.length == aggFields.length)
private var output: Row = _
protected var output: Row = _
// state to hold the accumulators of the aggregations
private var accumulatorState: ValueState[Row] = _
// state to hold rows until the next watermark arrives
......@@ -162,30 +162,9 @@ class UnboundedEventTimeOverProcessFunction(
val curRowList = rowMapState.get(curTimestamp)
collector.setAbsoluteTimestamp(curTimestamp)
var j = 0
while (j < curRowList.size) {
val curRow = curRowList.get(j)
i = 0
// copy forwarded fields to output row
while (i < forwardedFieldCount) {
output.setField(i, curRow.getField(i))
i += 1
}
// update accumulators and copy aggregates to output row
i = 0
while (i < aggregates.length) {
val index = forwardedFieldCount + i
val accumulator = lastAccumulator.getField(i).asInstanceOf[Accumulator]
aggregates(i).accumulate(accumulator, curRow.getField(aggFields(i)))
output.setField(index, aggregates(i).getValue(accumulator))
i += 1
}
// emit output row
collector.collect(output)
j += 1
}
// process the same timestamp datas, the mechanism is different according ROWS or RANGE
processElementsWithSameTimestamp(curRowList, lastAccumulator, collector)
rowMapState.remove(curTimestamp)
}
......@@ -204,21 +183,145 @@ class UnboundedEventTimeOverProcessFunction(
* If timestamps arrive in order (as in case of using the RocksDB state backend) this is just
* an append with O(1).
*/
private def insertToSortedList(recordTimeStamp: Long) = {
private def insertToSortedList(recordTimestamp: Long) = {
val listIterator = sortedTimestamps.listIterator(sortedTimestamps.size)
var continue = true
while (listIterator.hasPrevious && continue) {
val timestamp = listIterator.previous
if (recordTimeStamp >= timestamp) {
if (recordTimestamp >= timestamp) {
listIterator.next
listIterator.add(recordTimeStamp)
listIterator.add(recordTimestamp)
continue = false
}
}
if (continue) {
sortedTimestamps.addFirst(recordTimeStamp)
sortedTimestamps.addFirst(recordTimestamp)
}
}
/**
* Process the same timestamp datas, the mechanism is different between
* rows and range window.
*/
def processElementsWithSameTimestamp(
curRowList: JList[Row],
lastAccumulator: Row,
out: Collector[Row]): Unit
}
/**
* A ProcessFunction to support unbounded ROWS window.
* The ROWS clause defines on a physical level how many rows are included in a window frame.
*/
/**
 * A ProcessFunction to support unbounded event-time OVER ROWS windows.
 *
 * The ROWS clause defines the frame on a physical row level: every input row
 * gets its own frame, so each row sharing the current timestamp is first
 * folded into the accumulators and then emitted immediately with the
 * aggregate values seen up to and including that row.
 */
class UnboundedEventTimeRowsOverProcessFunction(
    aggregates: Array[AggregateFunction[_]],
    aggFields: Array[Int],
    forwardedFieldCount: Int,
    intermediateType: TypeInformation[Row],
    inputType: TypeInformation[Row])
  extends UnboundedEventTimeOverProcessFunction(
    aggregates,
    aggFields,
    forwardedFieldCount,
    intermediateType,
    inputType) {

  /**
   * Accumulates and emits the rows of one timestamp one by one (ROWS semantics).
   *
   * @param curRowList      rows that share the current timestamp
   * @param lastAccumulator accumulator row carried over from earlier timestamps,
   *                        updated in place
   * @param out             collector for the emitted result rows
   */
  override def processElementsWithSameTimestamp(
    curRowList: JList[Row],
    lastAccumulator: Row,
    out: Collector[Row]): Unit = {

    val rowCount = curRowList.size
    var rowIdx = 0
    while (rowIdx < rowCount) {
      val currentRow = curRowList.get(rowIdx)

      // forward the original input fields to the output row
      var fieldIdx = 0
      while (fieldIdx < forwardedFieldCount) {
        output.setField(fieldIdx, currentRow.getField(fieldIdx))
        fieldIdx += 1
      }

      // fold this row into each accumulator and write the updated aggregate
      var aggIdx = 0
      while (aggIdx < aggregates.length) {
        val acc = lastAccumulator.getField(aggIdx).asInstanceOf[Accumulator]
        aggregates(aggIdx).accumulate(acc, currentRow.getField(aggFields(aggIdx)))
        output.setField(forwardedFieldCount + aggIdx, aggregates(aggIdx).getValue(acc))
        aggIdx += 1
      }

      // ROWS semantics: emit after every single row
      out.collect(output)
      rowIdx += 1
    }
  }
}
/**
* A ProcessFunction to support unbounded RANGE window.
* The RANGE option includes all the rows within the window frame
* that have the same ORDER BY values as the current row.
*/
/**
 * A ProcessFunction to support unbounded event-time OVER RANGE windows.
 *
 * The RANGE clause includes in the frame all rows that have the same ORDER BY
 * value (here: the same timestamp) as the current row. Therefore all rows of
 * one timestamp must share identical aggregate values, which requires two
 * passes: first accumulate every peer row, then emit every peer row.
 */
class UnboundedEventTimeRangeOverProcessFunction(
    aggregates: Array[AggregateFunction[_]],
    aggFields: Array[Int],
    forwardedFieldCount: Int,
    intermediateType: TypeInformation[Row],
    inputType: TypeInformation[Row])
  extends UnboundedEventTimeOverProcessFunction(
    aggregates,
    aggFields,
    forwardedFieldCount,
    intermediateType,
    inputType) {

  /**
   * Accumulates all rows of one timestamp first, then emits them with the
   * same aggregate values (RANGE semantics).
   *
   * @param curRowList      rows that share the current timestamp
   * @param lastAccumulator accumulator row carried over from earlier timestamps,
   *                        updated in place
   * @param out             collector for the emitted result rows
   */
  override def processElementsWithSameTimestamp(
    curRowList: JList[Row],
    lastAccumulator: Row,
    out: Collector[Row]): Unit = {

    var j = 0
    var i = 0
    // First pass: fold ALL rows with this timestamp into the accumulators,
    // because all peer rows must see the same aggregation value.
    while (j < curRowList.size) {
      val curRow = curRowList.get(j)
      i = 0
      while (i < aggregates.length) {
        // fixed: dropped the unused local `index` the original computed here
        val accumulator = lastAccumulator.getField(i).asInstanceOf[Accumulator]
        aggregates(i).accumulate(accumulator, curRow.getField(aggFields(i)))
        i += 1
      }
      j += 1
    }

    // Second pass: emit one output row per input row; every emitted row
    // carries the identical, fully accumulated aggregate values.
    j = 0
    while (j < curRowList.size) {
      val curRow = curRowList.get(j)

      // copy forwarded fields to output row
      i = 0
      while (i < forwardedFieldCount) {
        output.setField(i, curRow.getField(i))
        i += 1
      }

      // copy aggregates to output row
      i = 0
      while (i < aggregates.length) {
        val accumulator = lastAccumulator.getField(i).asInstanceOf[Accumulator]
        output.setField(forwardedFieldCount + i, aggregates(i).getValue(accumulator))
        i += 1
      }
      out.collect(output)
      j += 1
    }
  }
}
......@@ -840,6 +840,138 @@ class SqlITCase extends StreamingWithStateTestBase {
"6,8,Hello world,51,9,5,9,1")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
/** Test event-time OVER RANGE UNBOUNDED PRECEDING window without partitioning. */
@Test
def testUnboundedNonPartitionedEventTimeRangeWindow(): Unit = {
  val env = StreamExecutionEnvironment.getExecutionEnvironment
  val tEnv = TableEnvironment.getTableEnvironment(env)
  env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
  env.setStateBackend(getStateBackend)
  StreamITCase.testResults = mutable.MutableList()
  env.setParallelism(1)

  // every aggregate shares the identical unbounded RANGE frame
  val overClause = "over (order by rowtime() range between unbounded preceding and current row)"
  val sqlQuery = "SELECT a, b, c, " +
    s"SUM(b) $overClause, " +
    s"count(b) $overClause, " +
    s"avg(b) $overClause, " +
    s"max(b) $overClause, " +
    s"min(b) $overClause " +
    "from T1"

  // Left = element with event timestamp, Right = watermark
  val data = Seq(
    Left(14000005L, (1, 1L, "Hi")),
    Left(14000000L, (2, 1L, "Hello")),
    Left(14000002L, (1, 1L, "Hello")),
    Left(14000002L, (1, 2L, "Hello")),
    Left(14000002L, (1, 3L, "Hello world")),
    Left(14000003L, (2, 2L, "Hello world")),
    Left(14000003L, (2, 3L, "Hello world")),
    Right(14000020L),
    Left(14000021L, (1, 4L, "Hello world")),
    Left(14000022L, (1, 5L, "Hello world")),
    Left(14000022L, (1, 6L, "Hello world")),
    Left(14000022L, (1, 7L, "Hello world")),
    Left(14000023L, (2, 4L, "Hello world")),
    Left(14000023L, (2, 5L, "Hello world")),
    Right(14000030L)
  )

  val table = env
    .addSource(new EventTimeSourceFunction[(Int, Long, String)](data))
    .toTable(tEnv)
    .as('a, 'b, 'c)
  tEnv.registerTable("T1", table)

  val result = tEnv.sql(sqlQuery).toDataStream[Row]
  result.addSink(new StreamITCase.StringSink)
  env.execute()

  // rows with equal timestamps must carry identical aggregate values
  val expected = mutable.MutableList(
    "2,1,Hello,1,1,1,1,1",
    "1,1,Hello,7,4,1,3,1",
    "1,2,Hello,7,4,1,3,1",
    "1,3,Hello world,7,4,1,3,1",
    "2,2,Hello world,12,6,2,3,1",
    "2,3,Hello world,12,6,2,3,1",
    "1,1,Hi,13,7,1,3,1",
    "1,4,Hello world,17,8,2,4,1",
    "1,5,Hello world,35,11,3,7,1",
    "1,6,Hello world,35,11,3,7,1",
    "1,7,Hello world,35,11,3,7,1",
    "2,4,Hello world,44,13,3,7,1",
    "2,5,Hello world,44,13,3,7,1"
  )
  assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
/** Test partitioned event-time OVER RANGE UNBOUNDED PRECEDING window. */
@Test
def testUnboundedPartitionedEventTimeRangeWindow(): Unit = {
  val env = StreamExecutionEnvironment.getExecutionEnvironment
  val tEnv = TableEnvironment.getTableEnvironment(env)
  env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
  env.setStateBackend(getStateBackend)
  StreamITCase.testResults = mutable.MutableList()
  env.setParallelism(1)

  // every aggregate shares the identical partitioned unbounded RANGE frame
  val overClause =
    "over (partition by a order by rowtime() range between unbounded preceding and current row)"
  val sqlQuery = "SELECT a, b, c, " +
    s"SUM(b) $overClause, " +
    s"count(b) $overClause, " +
    s"avg(b) $overClause, " +
    s"max(b) $overClause, " +
    s"min(b) $overClause " +
    "from T1"

  // Left = element with event timestamp, Right = watermark
  val data = Seq(
    Left(14000005L, (1, 1L, "Hi")),
    Left(14000000L, (2, 1L, "Hello")),
    Left(14000002L, (1, 1L, "Hello")),
    Left(14000002L, (1, 2L, "Hello")),
    Left(14000002L, (1, 3L, "Hello world")),
    Left(14000003L, (2, 2L, "Hello world")),
    Left(14000003L, (2, 3L, "Hello world")),
    Right(14000020L),
    Left(14000021L, (1, 4L, "Hello world")),
    Left(14000022L, (1, 5L, "Hello world")),
    Left(14000022L, (1, 6L, "Hello world")),
    Left(14000022L, (1, 7L, "Hello world")),
    Left(14000023L, (2, 4L, "Hello world")),
    Left(14000023L, (2, 5L, "Hello world")),
    Right(14000030L)
  )

  val table = env
    .addSource(new EventTimeSourceFunction[(Int, Long, String)](data))
    .toTable(tEnv)
    .as('a, 'b, 'c)
  tEnv.registerTable("T1", table)

  val result = tEnv.sql(sqlQuery).toDataStream[Row]
  result.addSink(new StreamITCase.StringSink)
  env.execute()

  // aggregates are computed per partition key `a`; peer timestamps share values
  val expected = mutable.MutableList(
    "1,1,Hello,6,3,2,3,1",
    "1,2,Hello,6,3,2,3,1",
    "1,3,Hello world,6,3,2,3,1",
    "1,1,Hi,7,4,1,3,1",
    "2,1,Hello,1,1,1,1,1",
    "2,2,Hello world,6,3,2,3,1",
    "2,3,Hello world,6,3,2,3,1",
    "1,4,Hello world,11,5,2,4,1",
    "1,5,Hello world,29,8,3,7,1",
    "1,6,Hello world,29,8,3,7,1",
    "1,7,Hello world,29,8,3,7,1",
    "2,4,Hello world,15,5,3,5,1",
    "2,5,Hello world,15,5,3,5,1"
  )
  assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
}
object SqlITCase {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册