Commit 88576a0e authored by Jark Wu, committed by twalthr

[FLINK-4692] [table] Add tumbling group-windows for batch tables

This closes #2938.
Parent fb3761b5
......@@ -1279,7 +1279,8 @@ A session window is defined by using the `Session` class as follows:
Currently the following features are not supported yet:
- Row-count windows on event-time
- - Windows on batch tables
+ - Session windows on batch tables
+ - Sliding windows on batch tables
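With this change, tumbling group-windows are supported on batch tables. For example (a sketch based on the tests added in this commit; the table and field names are illustrative):

    val result = table
      .groupBy('string)
      .window(Tumble over 5.milli on 'long as 'w)
      .select('string, 'int.sum, 'w.start, 'w.end)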
SQL
----
......
......@@ -802,9 +802,6 @@ class Table(
* @return A windowed table.
*/
def window(groupWindow: GroupWindow): GroupWindowedTable = {
- if (tableEnv.isInstanceOf[BatchTableEnvironment]) {
-   throw new ValidationException(s"Windows on batch tables are currently not supported.")
- }
new GroupWindowedTable(this, Seq(), groupWindow)
}
}
......@@ -872,9 +869,6 @@ class GroupedTable(
* @return A windowed table.
*/
def window(groupWindow: GroupWindow): GroupWindowedTable = {
- if (table.tableEnv.isInstanceOf[BatchTableEnvironment]) {
-   throw new ValidationException(s"Windows on batch tables are currently not supported.")
- }
new GroupWindowedTable(table, groupKey, groupWindow)
}
}
......
......@@ -36,7 +36,7 @@ abstract class Attribute extends LeafExpression with NamedExpression {
case class UnresolvedFieldReference(name: String) extends Attribute {
- override def toString = "'" + name
+ override def toString = s"'$name"
override private[flink] def withName(newName: String): Attribute =
UnresolvedFieldReference(newName)
......
......@@ -55,7 +55,21 @@ abstract class EventTimeGroupWindow(
}
}
- abstract class ProcessingTimeGroupWindow(name: Option[Expression]) extends LogicalWindow(name)
+ abstract class ProcessingTimeGroupWindow(name: Option[Expression]) extends LogicalWindow(name) {
override def validate(tableEnv: TableEnvironment): ValidationResult = {
val valid = super.validate(tableEnv)
if (valid.isFailure) {
return valid
}
tableEnv match {
case _: BatchTableEnvironment => ValidationFailure(
"Windows on batch tables must declare a time attribute over which the query is evaluated.")
case _ =>
ValidationSuccess
}
}
}
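As a consequence of this new validation, processing-time windows are rejected on batch tables. A minimal sketch of a query that now fails validation (mirroring the tests added in this commit):

    // throws ValidationException: batch windows must declare a time attribute
    table
      .groupBy('string)
      .window(Tumble over 50.milli) // processing time, no 'on' clause
      .select('string, 'int.count)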
// ------------------------------------------------------------------------------------------------
// Tumbling group windows
......@@ -107,9 +121,11 @@ case class EventTimeTumblingGroupWindow(
super.validate(tableEnv)
.orElse(TumblingGroupWindow.validate(tableEnv, size))
.orElse(size match {
- case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+ case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS)
+ if tableEnv.isInstanceOf[StreamTableEnvironment] =>
ValidationFailure(
- "Event-time grouping windows on row intervals are currently not supported.")
+ "Event-time grouping windows on row intervals in a stream environment " +
+ "are currently not supported.")
case _ =>
ValidationSuccess
})
......@@ -196,9 +212,11 @@ case class EventTimeSlidingGroupWindow(
super.validate(tableEnv)
.orElse(SlidingGroupWindow.validate(tableEnv, size, slide))
.orElse(size match {
- case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+ case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS)
+ if tableEnv.isInstanceOf[StreamTableEnvironment] =>
ValidationFailure(
- "Event-time grouping windows on row intervals are currently not supported.")
+ "Event-time grouping windows on row intervals in a stream environment " +
+ "are currently not supported.")
case _ =>
ValidationSuccess
})
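With the StreamTableEnvironment guard above, event-time row-interval windows still fail validation on streams but now pass on batch tables. A sketch of a count-based tumbling window that becomes valid (based on the batch tests in this commit):

    table
      .groupBy('string)
      .window(Tumble over 2.rows on 'long) // row interval on event-time, batch
      .select('string, 'int.count)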
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.nodes.dataset
import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.typeutils.{ResultTypeQueryable, RowTypeInfo}
import org.apache.flink.table.api.BatchTableEnvironment
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.plan.nodes.FlinkAggregate
import org.apache.flink.table.runtime.aggregate.AggregateUtil.{CalcitePair, _}
import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval
import org.apache.flink.table.typeutils.TypeConverter
import org.apache.flink.types.Row
import scala.collection.JavaConversions._
/**
* Flink RelNode which corresponds to a LogicalWindowAggregate.
*/
class DataSetWindowAggregate(
window: LogicalWindow,
namedProperties: Seq[NamedWindowProperty],
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputNode: RelNode,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
rowRelDataType: RelDataType,
inputType: RelDataType,
grouping: Array[Int])
extends SingleRel(cluster, traitSet, inputNode)
with FlinkAggregate
with DataSetRel {
override def deriveRowType() = rowRelDataType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataSetWindowAggregate(
window,
namedProperties,
cluster,
traitSet,
inputs.get(0),
namedAggregates,
getRowType,
inputType,
grouping)
}
override def toString: String = {
s"Aggregate(${
if (!grouping.isEmpty) {
s"groupBy: (${groupingToString(inputType, grouping)}), "
} else {
""
}
}window: ($window), " +
s"select: (${
aggregationToString(
inputType,
grouping,
getRowType,
namedAggregates,
namedProperties)
}))"
}
override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw)
.itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty)
.item("window", window)
.item(
"select", aggregationToString(
inputType,
grouping,
getRowType,
namedAggregates,
namedProperties))
}
override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val child = this.getInput
val rowCnt = metadata.getRowCount(child)
val rowSize = this.estimateRowSize(child.getRowType)
val aggCnt = this.namedAggregates.size
planner.getCostFactory.makeCost(rowCnt, rowCnt * aggCnt, rowCnt * rowSize)
}
override def translateToPlan(
tableEnv: BatchTableEnvironment,
expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {
val config = tableEnv.getConfig
val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(
tableEnv,
// tell the input operator that this operator currently only supports Rows as input
Some(TypeConverter.DEFAULT_ROW_TYPE))
// whether identifiers are matched case-sensitively
val caseSensitive = tableEnv.getFrameworkConfig.getParserConfig.caseSensitive()
val result = window match {
case EventTimeTumblingGroupWindow(_, _, size) =>
createEventTimeTumblingWindowDataSet(
inputDS,
isTimeInterval(size.resultType),
caseSensitive)
case EventTimeSessionGroupWindow(_, _, _) =>
throw new UnsupportedOperationException(
"Event-time session windows in a batch environment are currently not supported")
case EventTimeSlidingGroupWindow(_, _, _, _) =>
throw new UnsupportedOperationException(
"Event-time sliding windows in a batch environment are currently not supported")
case _: ProcessingTimeGroupWindow =>
throw new UnsupportedOperationException(
"Processing-time tumbling windows are not supported in a batch environment, " +
"windows in a batch environment must declare a time attribute over which " +
"the query is evaluated.")
}
// if the expected type is not a Row, inject a mapper to convert to the expected type
expectedType match {
case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] =>
val mapName = s"convert: (${getRowType.getFieldNames.toList.mkString(", ")})"
result.map(
getConversionMapper(
config = config,
nullableInput = false,
inputType = resultRowTypeInfo.asInstanceOf[TypeInformation[Any]],
expectedType = expectedType.get,
conversionOperatorName = "DataSetWindowAggregateConversion",
fieldNames = getRowType.getFieldNames
))
.name(mapName)
case _ => result
}
}
private def createEventTimeTumblingWindowDataSet(
inputDS: DataSet[Any],
isTimeWindow: Boolean,
isParserCaseSensitive: Boolean)
: DataSet[Any] = {
val mapFunction = createDataSetWindowPrepareMapFunction(
window,
namedAggregates,
grouping,
inputType,
isParserCaseSensitive)
val groupReduceFunction = createDataSetWindowAggGroupReduceFunction(
window,
namedAggregates,
inputType,
getRowType,
grouping,
namedProperties)
val mappedInput = inputDS
.map(mapFunction)
.name(prepareOperatorName)
val mapReturnType = mapFunction.asInstanceOf[ResultTypeQueryable[Row]].getProducedType
if (isTimeWindow) {
// grouped time window aggregation
// group by grouping keys and rowtime field (the last field in the row)
val groupingKeys = grouping.indices ++ Seq(mapReturnType.getArity - 1)
mappedInput.asInstanceOf[DataSet[Row]]
.groupBy(groupingKeys: _*)
.reduceGroup(groupReduceFunction)
.returns(resultRowTypeInfo)
.name(aggregateOperatorName)
.asInstanceOf[DataSet[Any]]
} else {
// count window
val groupingKeys = grouping.indices.toArray
if (groupingKeys.length > 0) {
// grouped aggregation
mappedInput.asInstanceOf[DataSet[Row]]
.groupBy(groupingKeys: _*)
// sort on time field, it's the last element in the row
.sortGroup(mapReturnType.getArity - 1, Order.ASCENDING)
.reduceGroup(groupReduceFunction)
.returns(resultRowTypeInfo)
.name(aggregateOperatorName)
.asInstanceOf[DataSet[Any]]
} else {
// TODO: a count tumbling all-window on event-time should sort the whole data set
// on event time before applying the windowing logic.
throw new UnsupportedOperationException(
"Count tumbling non-grouping windows on event-time are currently not supported.")
}
}
}
private def prepareOperatorName: String = {
val aggString = aggregationToString(
inputType,
grouping,
getRowType,
namedAggregates,
namedProperties)
s"prepare select: ($aggString)"
}
private def aggregateOperatorName: String = {
val aggString = aggregationToString(
inputType,
grouping,
getRowType,
namedAggregates,
namedProperties)
if (grouping.length > 0) {
s"groupBy: (${groupingToString(inputType, grouping)}), " +
s"window: ($window), select: ($aggString)"
} else {
s"window: ($window), select: ($aggString)"
}
}
private def resultRowTypeInfo: RowTypeInfo = {
// get the output types
val fieldTypes: Array[TypeInformation[_]] = getRowType.getFieldList
.map(field => FlinkTypeFactory.toTypeInfo(field.getType))
.toArray
new RowTypeInfo(fieldTypes: _*)
}
}
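The key idea in the translation above: the prepare mapper aligns each record's rowtime to its window start, so a tumbling time window reduces to an ordinary groupBy on (grouping keys, aligned rowtime). For example, with a 5 milli window, rows with rowtimes 3 and 4 both carry the aligned key 0 and meet in the same group, while a row with rowtime 6 carries key 5 and forms the next window.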
......@@ -97,6 +97,7 @@ object FlinkRuleSets {
CalcMergeRule.INSTANCE,
// translate to Flink DataSet nodes
DataSetWindowAggregateRule.INSTANCE,
DataSetAggregateRule.INSTANCE,
DataSetAggregateWithNullValuesRule.INSTANCE,
DataSetCalcRule.INSTANCE,
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.rules.dataSet
import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate
import org.apache.flink.table.plan.nodes.dataset.{DataSetConvention, DataSetWindowAggregate}
import scala.collection.JavaConversions._
class DataSetWindowAggregateRule
extends ConverterRule(
classOf[LogicalWindowAggregate],
Convention.NONE,
DataSetConvention.INSTANCE,
"DataSetWindowAggregateRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val agg: LogicalWindowAggregate = call.rel(0).asInstanceOf[LogicalWindowAggregate]
// check if we have distinct aggregates
val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
if (distinctAggs) {
throw TableException("DISTINCT aggregates are currently not supported.")
}
// check if we have grouping sets
val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet
if (groupSets || agg.indicator) {
throw TableException("GROUPING SETS are currently not supported.")
}
!distinctAggs && !groupSets && !agg.indicator
}
override def convert(rel: RelNode): RelNode = {
val agg: LogicalWindowAggregate = rel.asInstanceOf[LogicalWindowAggregate]
val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
val convInput: RelNode = RelOptRule.convert(agg.getInput, DataSetConvention.INSTANCE)
new DataSetWindowAggregate(
agg.getWindow,
agg.getNamedProperties,
rel.getCluster,
traitSet,
convInput,
agg.getNamedAggCalls,
rel.getRowType,
agg.getInput.getRowType,
agg.getGroupSet.toArray)
}
}
object DataSetWindowAggregateRule {
val INSTANCE: RelOptRule = new DataSetWindowAggregateRule
}
......@@ -23,7 +23,7 @@ import org.apache.calcite.rel.`type`._
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.sql.{SqlAggFunction, SqlKind}
import org.apache.calcite.sql.`type`.SqlTypeName._
- import org.apache.calcite.sql.`type`.{SqlTypeFactoryImpl, SqlTypeName}
+ import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.fun._
import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction}
import org.apache.flink.api.common.typeinfo.TypeInformation
......@@ -31,12 +31,13 @@ import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
import FlinkRelBuilder.NamedWindowProperty
- import org.apache.flink.table.expressions.{WindowEnd, WindowStart}
+ import org.apache.flink.table.expressions._
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.typeutils.TypeCheckUtils._
import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
- import org.apache.flink.table.api.TableException
+ import org.apache.flink.table.api.{TableException, Types}
import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo}
import org.apache.flink.types.Row
import scala.collection.JavaConversions._
......@@ -87,6 +88,151 @@ object AggregateUtil {
mapFunction
}
/**
* Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates.
* The output of the function contains the grouping keys, the timestamp, and the intermediate
* aggregate values of all aggregate functions. The timestamp field is aligned to the time window
* start and used as a grouping key in case of a time window. In case of a count window on
* event-time, the timestamp is not aligned and is used for sorting.
*
* The output is stored in Row by the following format:
*
* {{{
* avg(x) aggOffsetInRow = 2 count(z) aggOffsetInRow = 5
* | |
* v v
* +---------+---------+--------+--------+--------+--------+--------+
* |groupKey1|groupKey2| sum1 | count1 | sum2 | count2 | rowtime|
* +---------+---------+--------+--------+--------+--------+--------+
* ^ ^
* | |
* sum(y) aggOffsetInRow = 4 rowtime to group or sort
* }}}
*
* NOTE: this function is only used for time-based windows on batch tables.
*/
def createDataSetWindowPrepareMapFunction(
window: LogicalWindow,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
groupings: Array[Int],
inputType: RelDataType,
isParserCaseSensitive: Boolean): MapFunction[Any, Row] = {
val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
groupings.length)
val mapReturnType: RowTypeInfo =
createAggregateBufferDataType(groupings, aggregates, inputType, Some(Types.LONG))
val (timeFieldPos, tumbleTimeWindowSize) = window match {
case EventTimeTumblingGroupWindow(_, time, size) =>
val timeFieldPos = getTimeFieldPosition(time, inputType, isParserCaseSensitive)
size match {
case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
(timeFieldPos, Some(value))
case _ => (timeFieldPos, None)
}
case EventTimeSessionGroupWindow(_, time, _) =>
(getTimeFieldPosition(time, inputType, isParserCaseSensitive), None)
case _ =>
throw new UnsupportedOperationException(s"$window is currently not supported on batch")
}
new DataSetWindowAggregateMapFunction(
aggregates,
aggFieldIndexes,
groupings,
timeFieldPos,
tumbleTimeWindowSize,
mapReturnType).asInstanceOf[MapFunction[Any, Row]]
}
/**
* Create a [[org.apache.flink.api.common.functions.GroupReduceFunction]] to compute window
* aggregates on batch tables. If all aggregates support partial aggregation and the window is
* a time window, the [[org.apache.flink.api.common.functions.GroupReduceFunction]] implements
* [[org.apache.flink.api.common.functions.CombineFunction]] as well.
*
* NOTE: this function is only used for windows on batch tables.
*/
def createDataSetWindowAggGroupReduceFunction(
window: LogicalWindow,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
outputType: RelDataType,
groupings: Array[Int],
properties: Seq[NamedWindowProperty]): RichGroupReduceFunction[Row, Row] = {
val aggregates = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
groupings.length)._2
// one additional field is used to store the time attribute
val intermediateRowArity = groupings.length +
aggregates.map(_.intermediateDataType.length).sum + 1
// the mapping relation between field index of intermediate aggregate Row and output Row.
val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings)
// the mapping relation between aggregate function index in list and its corresponding
// field index in output Row.
val aggOffsetMapping = getAggregateMapping(namedAggregates, outputType)
if (groupingOffsetMapping.length != groupings.length ||
aggOffsetMapping.length != namedAggregates.length) {
throw new TableException(
"Could not find output field in input data type " +
"or aggregate functions.")
}
window match {
case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) =>
// tumbling time window
val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
if (aggregates.forall(_.supportPartial)) {
// for incremental aggregations
new DataSetTumbleTimeWindowAggReduceCombineFunction(
intermediateRowArity - 1,
asLong(size),
startPos,
endPos,
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
intermediateRowArity,
outputType.getFieldCount)
}
else {
// for non-incremental aggregations
new DataSetTumbleTimeWindowAggReduceGroupFunction(
intermediateRowArity - 1,
asLong(size),
startPos,
endPos,
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
intermediateRowArity,
outputType.getFieldCount)
}
case EventTimeTumblingGroupWindow(_, _, size) =>
// tumbling count window
new DataSetTumbleCountWindowAggReduceGroupFunction(
asLong(size),
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
intermediateRowArity,
outputType.getFieldCount)
case _ =>
throw new UnsupportedOperationException(s"$window is currently not supported on batch")
}
}
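A note on the supportPartial branch above: the combining variant additionally implements [[org.apache.flink.api.common.functions.CombineFunction]], which lets the DataSet runtime pre-aggregate rows locally before the shuffle; aggregates that cannot be computed partially fall back to the plain group-reduce variant, which ships all prepared rows of a group to the reducer.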
/**
* Create a [[org.apache.flink.api.common.functions.GroupReduceFunction]] to compute aggregates.
* If all aggregates support partial aggregation, the
......@@ -360,7 +506,7 @@ object AggregateUtil {
}
}
- private def computeWindowStartEndPropertyPos(
+ private[flink] def computeWindowStartEndPropertyPos(
properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int]) = {
val propPos = properties.foldRight((None: Option[Int], None: Option[Int], 0)) {
......@@ -517,22 +663,20 @@ object AggregateUtil {
private def createAggregateBufferDataType(
groupings: Array[Int],
aggregates: Array[Aggregate[_]],
- inputType: RelDataType): RowTypeInfo = {
+ inputType: RelDataType,
+ windowKeyType: Option[TypeInformation[_]] = None): RowTypeInfo = {
// get the field data types of group keys.
val groupingTypes: Seq[TypeInformation[_]] = groupings
.map(inputType.getFieldList.get(_).getType)
.map(FlinkTypeFactory.toTypeInfo)
- val aggPartialNameSuffix = "agg_buffer_"
- val factory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
// get all field data types of all intermediate aggregates
val aggTypes: Seq[TypeInformation[_]] = aggregates.flatMap(_.intermediateDataType)
- // concat group key types and aggregation types
- val allFieldTypes = groupingTypes ++: aggTypes
- val partialType = new RowTypeInfo(allFieldTypes: _*)
+ // concat group key types and aggregation types, and window key types (may be empty)
+ val allFieldTypes = groupingTypes ++: aggTypes ++: windowKeyType
+ val partialType = new RowTypeInfo(allFieldTypes.toSeq: _*)
partialType
}
......@@ -591,5 +735,40 @@ object AggregateUtil {
groupingOffsetMapping.toArray
}
private def getTimeFieldPosition(
timeField: Expression,
inputType: RelDataType,
isParserCaseSensitive: Boolean): Int = {
timeField match {
case ResolvedFieldReference(name, _) =>
// get the RelDataType referenced by the time-field
val relDataType = inputType.getFieldList.filter { r =>
if (isParserCaseSensitive) {
name.equals(r.getName)
} else {
name.equalsIgnoreCase(r.getName)
}
}
// should only match one
if (relDataType.length == 1) {
relDataType.head.getIndex
} else {
throw TableException(
s"Encountered more than one time attribute with the same name: $relDataType")
}
case e => throw TableException(
"The time attribute of window in batch environment should be " +
s"ResolvedFieldReference, but is $e")
}
}
private def asLong(expr: Expression): Long = expr match {
case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) => value
case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value
case _ => throw new IllegalArgumentException()
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.configuration.Configuration
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
/**
* It wraps the aggregate logic inside of
* [[org.apache.flink.api.java.operators.GroupReduceOperator]].
* It is only used for tumbling count-window on batch.
*
* @param windowSize Tumble count window size
* @param aggregates The aggregate functions.
* @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
* and output Row.
* @param aggregateMapping The index mapping between aggregate function list and aggregated value
* index in output Row.
* @param intermediateRowArity The intermediate row field count
* @param finalRowArity The output row field count
*/
class DataSetTumbleCountWindowAggReduceGroupFunction(
private val windowSize: Long,
private val aggregates: Array[Aggregate[_ <: Any]],
private val groupKeysMapping: Array[(Int, Int)],
private val aggregateMapping: Array[(Int, Int)],
private val intermediateRowArity: Int,
private val finalRowArity: Int)
extends RichGroupReduceFunction[Row, Row] {
private var aggregateBuffer: Row = _
private var output: Row = _
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(groupKeysMapping)
aggregateBuffer = new Row(intermediateRowArity)
output = new Row(finalRowArity)
}
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
var count: Long = 0
val iterator = records.iterator()
while (iterator.hasNext) {
val record = iterator.next()
if (count == 0) {
// initiate intermediate aggregate value.
aggregates.foreach(_.initiate(aggregateBuffer))
}
// merge intermediate aggregate value to buffer.
aggregates.foreach(_.merge(record, aggregateBuffer))
count += 1
if (windowSize == count) {
// set group keys value to final output.
groupKeysMapping.foreach {
case (after, previous) =>
output.setField(after, record.getField(previous))
}
// evaluate final aggregate value and set to output.
aggregateMapping.foreach {
case (after, previous) =>
output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
}
// emit the output
out.collect(output)
count = 0
}
}
}
}
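A worked example of these semantics, taken from the ITCase added in this commit: with windowSize = 2, the group "Hello" contains the time-sorted rows (3L, 2), (4L, 5), (6L, 3); the first two rows fill a window and emit the sum 7, while the trailing row is discarded because the count never reaches the window size again.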
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
import org.apache.flink.api.common.functions.CombineFunction
import org.apache.flink.types.Row
/**
* It wraps the aggregate logic inside of
* [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
* [[org.apache.flink.api.java.operators.GroupCombineOperator]].
* It is used for tumbling time-window on batch.
*
* @param rowtimePos The rowtime field index in input row
* @param windowSize Tumbling time window size
* @param windowStartPos The relative window-start field position to the last field of output row
* @param windowEndPos The relative window-end field position to the last field of output row
* @param aggregates The aggregate functions.
* @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
* and output Row.
* @param aggregateMapping The index mapping between aggregate function list and aggregated value
* index in output Row.
* @param intermediateRowArity The intermediate row field count
* @param finalRowArity The output row field count
*/
class DataSetTumbleTimeWindowAggReduceCombineFunction(
rowtimePos: Int,
windowSize: Long,
windowStartPos: Option[Int],
windowEndPos: Option[Int],
aggregates: Array[Aggregate[_ <: Any]],
groupKeysMapping: Array[(Int, Int)],
aggregateMapping: Array[(Int, Int)],
intermediateRowArity: Int,
finalRowArity: Int)
extends DataSetTumbleTimeWindowAggReduceGroupFunction(
rowtimePos,
windowSize,
windowStartPos,
windowEndPos,
aggregates,
groupKeysMapping,
aggregateMapping,
intermediateRowArity,
finalRowArity)
with CombineFunction[Row, Row] {
/**
* For sub-grouped intermediate aggregate Rows, merge all of them into the aggregate buffer.
*
* @param records Sub-grouped intermediate aggregate Rows iterator.
* @return Combined intermediate aggregate Row.
*
*/
override def combine(records: Iterable[Row]): Row = {
// initiate intermediate aggregate value.
aggregates.foreach(_.initiate(aggregateBuffer))
// merge intermediate aggregate value to buffer.
var last: Row = null
val iterator = records.iterator()
while (iterator.hasNext) {
val record = iterator.next()
aggregates.foreach(_.merge(record, aggregateBuffer))
last = record
}
// set group keys to aggregateBuffer.
for (i <- groupKeysMapping.indices) {
aggregateBuffer.setField(i, last.getField(i))
}
// set the rowtime attribute
aggregateBuffer.setField(rowtimePos, last.getField(rowtimePos))
aggregateBuffer
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
/**
* It wraps the aggregate logic inside of
* [[org.apache.flink.api.java.operators.GroupReduceOperator]]. It is used for tumbling time-window
* on batch.
*
* @param rowtimePos The rowtime field index in input row
* @param windowSize Tumbling time window size
* @param windowStartPos The relative window-start field position to the last field of output row
* @param windowEndPos The relative window-end field position to the last field of output row
* @param aggregates The aggregate functions.
* @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
* and output Row.
* @param aggregateMapping The index mapping between aggregate function list and aggregated value
* index in output Row.
* @param intermediateRowArity The intermediate row field count
* @param finalRowArity The output row field count
*/
class DataSetTumbleTimeWindowAggReduceGroupFunction(
rowtimePos: Int,
windowSize: Long,
windowStartPos: Option[Int],
windowEndPos: Option[Int],
aggregates: Array[Aggregate[_ <: Any]],
groupKeysMapping: Array[(Int, Int)],
aggregateMapping: Array[(Int, Int)],
intermediateRowArity: Int,
finalRowArity: Int)
extends RichGroupReduceFunction[Row, Row] {
private var collector: TimeWindowPropertyCollector = _
protected var aggregateBuffer: Row = _
private var output: Row = _
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(groupKeysMapping)
aggregateBuffer = new Row(intermediateRowArity)
output = new Row(finalRowArity)
collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
}
override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
// initiate intermediate aggregate value.
aggregates.foreach(_.initiate(aggregateBuffer))
// merge intermediate aggregate value to buffer.
var last: Row = null
val iterator = records.iterator()
while (iterator.hasNext) {
val record = iterator.next()
aggregates.foreach(_.merge(record, aggregateBuffer))
last = record
}
// set group keys value to final output.
groupKeysMapping.foreach {
case (after, previous) =>
output.setField(after, last.getField(previous))
}
// evaluate final aggregate value and set to output.
aggregateMapping.foreach {
case (after, previous) =>
output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
}
// get window start timestamp
val startTs: Long = last.getField(rowtimePos).asInstanceOf[Long]
// set collector and window
collector.wrappedCollector = out
collector.timeWindow = new TimeWindow(startTs, startTs + windowSize)
collector.collect(output)
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.aggregate
import java.sql.Timestamp
import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.flink.types.Row
import org.apache.flink.util.Preconditions
/**
* This map function only works for windows on batch tables. The difference between this function
* and [[org.apache.flink.table.runtime.aggregate.AggregateMapFunction]] is that this function
* appends an (aligned) rowtime field to the end of the output row.
*/
class DataSetWindowAggregateMapFunction(
private val aggregates: Array[Aggregate[_]],
private val aggFields: Array[Int],
private val groupingKeys: Array[Int],
private val timeFieldPos: Int, // time field position in input row
private val tumbleTimeWindowSize: Option[Long],
@transient private val returnType: TypeInformation[Row])
extends RichMapFunction[Row, Row] with ResultTypeQueryable[Row] {
private var output: Row = _
// rowtime index in the buffer output row
private var rowtimeIndex: Int = _
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(aggFields)
Preconditions.checkArgument(aggregates.length == aggFields.length)
// add one more arity to store rowtime
val partialRowLength = groupingKeys.length +
aggregates.map(_.intermediateDataType.length).sum + 1
// set rowtime to the last field of the output row
rowtimeIndex = partialRowLength - 1
output = new Row(partialRowLength)
}
override def map(input: Row): Row = {
for (i <- aggregates.indices) {
val fieldValue = input.getField(aggFields(i))
aggregates(i).prepare(fieldValue, output)
}
for (i <- groupingKeys.indices) {
output.setField(i, input.getField(groupingKeys(i)))
}
val timeField = input.getField(timeFieldPos)
val rowtime = getTimestamp(timeField)
if (tumbleTimeWindowSize.isDefined) {
// in case of tumble time window, align rowtime to window start to represent the window
output.setField(
rowtimeIndex,
TimeWindow.getWindowStartWithOffset(rowtime, 0L, tumbleTimeWindowSize.get))
} else {
// otherwise, set rowtime for future use
output.setField(rowtimeIndex, rowtime)
}
output
}
private def getTimestamp(timeField: Any): Long = {
timeField match {
case b: Byte => b.toLong
case t: Character => t.toLong
case s: Short => s.toLong
case i: Int => i.toLong
case l: Long => l
case f: Float => f.toLong
case d: Double => d.toLong
case s: String => s.toLong
case t: Timestamp => t.getTime
case _ =>
throw new RuntimeException(
s"Window time field doesn't support ${timeField.getClass} type currently")
}
}
override def getProducedType: TypeInformation[Row] = {
returnType
}
}
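To make the alignment concrete (a worked example, not code from this commit): for a zero offset, TimeWindow.getWindowStartWithOffset(timestamp, 0L, size) computes timestamp - (timestamp % size), so with size = 5 a rowtime of 7 is stored as 5 (window [5, 10)) and a rowtime of 12 as 10 (window [10, 15)).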
......@@ -341,5 +341,4 @@ class AggregationsITCase(
val results = t.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
}
......@@ -303,34 +303,6 @@ class FieldProjectionTest extends TableTestBase {
util.verifyTable(resultTable, expected)
}
- @Test(expected = classOf[ValidationException])
- def testSelectFromBatchWindow1(): Unit = {
-   val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
-   // time field is selected
-   val resultTable = sourceTable
-     .window(Tumble over 5.millis on 'a as 'w)
-     .select('a.sum, 'c.count)
-   val expected = "TODO"
-   util.verifyTable(resultTable, expected)
- }
- @Test(expected = classOf[ValidationException])
- def testSelectFromBatchWindow2(): Unit = {
-   val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
-   // time field is not selected
-   val resultTable = sourceTable
-     .window(Tumble over 5.millis on 'a as 'w)
-     .select('c.count)
-   val expected = "TODO"
-   util.verifyTable(resultTable, expected)
- }
}
object FieldProjectionTest {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.api.scala.batch.table
import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.utils.TableTestBase
import org.apache.flink.table.utils.TableTestUtil._
import org.junit.Test
class GroupWindowTest extends TableTestBase {
//===============================================================================================
// Tumbling Windows
//===============================================================================================
@Test(expected = classOf[ValidationException])
def testProcessingTimeTumblingGroupWindowOverTime(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
table
.groupBy('string)
.window(Tumble over 50.milli) // require a time attribute
.select('string, 'int.count)
}
@Test(expected = classOf[ValidationException])
def testProcessingTimeTumblingGroupWindowOverCount(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
table
.groupBy('string)
.window(Tumble over 2.rows) // require a time attribute
.select('string, 'int.count)
}
@Test
def testEventTimeTumblingGroupWindowOverCount(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
val windowedTable = table
.groupBy('string)
.window(Tumble over 2.rows on 'long)
.select('string, 'int.count)
val expected = unaryNode(
"DataSetWindowAggregate",
batchTableNode(0),
term("groupBy", "string"),
term("window", EventTimeTumblingGroupWindow(None, 'long, 2.rows)),
term("select", "string", "COUNT(int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
}
@Test
def testEventTimeTumblingGroupWindowOverTime(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
val windowedTable = table
.groupBy('string)
.window(Tumble over 5.milli on 'long)
.select('string, 'int.count)
val expected = unaryNode(
"DataSetWindowAggregate",
batchTableNode(0),
term("groupBy", "string"),
term("window", EventTimeTumblingGroupWindow(None, 'long, 5.milli)),
term("select", "string", "COUNT(int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
}
@Test(expected = classOf[ValidationException])
def testAllProcessingTimeTumblingGroupWindowOverTime(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
table
.window(Tumble over 50.milli) // require a time attribute
.select('string, 'int.count)
}
@Test(expected = classOf[ValidationException])
def testAllProcessingTimeTumblingGroupWindowOverCount(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
table
.window(Tumble over 2.rows) // require a time attribute
.select('int.count)
}
@Test
def testAllEventTimeTumblingGroupWindowOverTime(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
val windowedTable = table
.window(Tumble over 5.milli on 'long)
.select('int.count)
val expected = unaryNode(
"DataSetWindowAggregate",
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "int", "long")
),
term("window", EventTimeTumblingGroupWindow(None, 'long, 5.milli)),
term("select", "COUNT(int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
}
@Test
def testAllEventTimeTumblingGroupWindowOverCount(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
val windowedTable = table
.window(Tumble over 2.rows on 'long)
.select('int.count)
val expected = unaryNode(
"DataSetWindowAggregate",
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "int", "long")
),
term("window", EventTimeTumblingGroupWindow(None, 'long, 2.rows)),
term("select", "COUNT(int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
}
//===============================================================================================
// Sliding Windows
//===============================================================================================
@Test(expected = classOf[ValidationException])
def testProcessingTimeSlidingGroupWindowOverTime(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
table
.groupBy('string)
.window(Slide over 50.milli every 50.milli) // requires a time attribute
.select('string, 'int.count)
}
@Test(expected = classOf[ValidationException])
def testProcessingTimeSlidingGroupWindowOverCount(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
table
.groupBy('string)
.window(Slide over 10.rows every 5.rows) // requires a time attribute
.select('string, 'int.count)
}
@Test
def testEventTimeSlidingGroupWindowOverTime(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
val windowedTable = table
.groupBy('string)
.window(Slide over 8.milli every 10.milli on 'long)
.select('string, 'int.count)
val expected = unaryNode(
"DataSetWindowAggregate",
batchTableNode(0),
term("groupBy", "string"),
term("window", EventTimeSlidingGroupWindow(None, 'long, 8.milli, 10.milli)),
term("select", "string", "COUNT(int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
}
@Test
def testEventTimeSlidingGroupWindowOverCount(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
val windowedTable = table
.groupBy('string)
.window(Slide over 2.rows every 1.rows on 'long)
.select('string, 'int.count)
val expected = unaryNode(
"DataSetWindowAggregate",
batchTableNode(0),
term("groupBy", "string"),
term("window", EventTimeSlidingGroupWindow(None, 'long, 2.rows, 1.rows)),
term("select", "string", "COUNT(int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
}
@Test(expected = classOf[ValidationException])
def testAllProcessingTimeSlidingGroupWindowOverCount(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
table
.window(Slide over 2.rows every 1.rows) // requires a time attribute
.select('int.count)
}
@Test
def testAllEventTimeSlidingGroupWindowOverTime(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
val windowedTable = table
.window(Slide over 8.milli every 10.milli on 'long)
.select('int.count)
val expected = unaryNode(
"DataSetWindowAggregate",
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "int", "long")
),
term("window", EventTimeSlidingGroupWindow(None, 'long, 8.milli, 10.milli)),
term("select", "COUNT(int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
}
@Test
def testAllEventTimeSlidingGroupWindowOverCount(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
val windowedTable = table
.window(Slide over 2.rows every 1.rows on 'long)
.select('int.count)
val expected = unaryNode(
"DataSetWindowAggregate",
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "int", "long")
),
term("window", EventTimeSlidingGroupWindow(None, 'long, 2.rows, 1.rows)),
term("select", "COUNT(int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
}
//===============================================================================================
// Session Windows
//===============================================================================================
@Test
def testEventTimeSessionGroupWindowOverTime(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
val windowedTable = table
.groupBy('string)
.window(Session withGap 7.milli on 'long)
.select('string, 'int.count)
val expected = unaryNode(
"DataSetWindowAggregate",
batchTableNode(0),
term("groupBy", "string"),
term("window", EventTimeSessionGroupWindow(None, 'long, 7.milli)),
term("select", "string", "COUNT(int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
}
@Test
def testAllEventTimeSessionGroupWindowOverTime(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
val windowedTable = table
.window(Session withGap 7.milli on 'long)
.select('int.count)
val expected = unaryNode(
"DataSetWindowAggregate",
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "int", "long")
),
term("window", EventTimeSessionGroupWindow(None, 'long, 7.milli)),
term("select", "COUNT(int) AS TMP_0")
)
util.verifyTable(windowedTable, expected)
}
}
......@@ -30,17 +30,6 @@ import org.junit.{Ignore, Test}
class GroupWindowTest extends TableTestBase {
- // batch windows are not supported yet
- @Test(expected = classOf[ValidationException])
- def testInvalidBatchWindow(): Unit = {
-   val util = batchTestUtil()
-   val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
-   table
-     .groupBy('string)
-     .window(Session withGap 100.milli as 'string)
- }
@Test(expected = classOf[ValidationException])
def testInvalidWindowProperty(): Unit = {
val util = streamTestUtil()
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.dataset
import org.apache.flink.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase
import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.apache.flink.test.util.TestBaseUtils
import org.apache.flink.types.Row
import org.junit._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import scala.collection.JavaConverters._
@RunWith(classOf[Parameterized])
class DataSetWindowAggregateITCase(
mode: TestExecutionMode,
configMode: TableConfigMode)
extends TableProgramsTestBase(mode, configMode) {
val data = List(
(1L, 1, "Hi"),
(2L, 2, "Hallo"),
(3L, 2, "Hello"),
(6L, 3, "Hello"),
(4L, 5, "Hello"),
(16L, 4, "Hello world"),
(8L, 3, "Hello world"))
@Test(expected = classOf[UnsupportedOperationException])
def testAllEventTimeTumblingWindowOverCount(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
// Count tumbling non-grouping windows on event-time are currently not supported
table
.window(Tumble over 2.rows on 'long)
.select('int.count)
.toDataSet[Row]
}
@Test
def testEventTimeTumblingGroupWindowOverCount(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
val windowedTable = table
.groupBy('string)
.window(Tumble over 2.rows on 'long)
.select('string, 'int.sum)
val expected = "Hello,7\n" + "Hello world,7\n"
val results = windowedTable.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test
def testEventTimeTumblingGroupWindowOverTime(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
val windowedTable = table
.groupBy('string)
.window(Tumble over 5.milli on 'long as 'w)
.select('string, 'int.sum, 'w.start, 'w.end)
val expected = "Hello world,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01\n" +
"Hello world,4,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02\n" +
"Hello,7,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005\n" +
"Hello,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01\n" +
"Hallo,2,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005\n" +
"Hi,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005\n"
val results = windowedTable.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test
def testAllEventTimeTumblingWindowOverTime(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
val windowedTable = table
.window(Tumble over 5.milli on 'long as 'w)
.select('int.sum, 'w.start, 'w.end)
val expected = "10,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005\n" +
"6,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01\n" +
"4,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02\n"
val results = windowedTable.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
}