diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md
index 2a75c6fca3a0509cadccbe66f12eea29d5b9a107..0fcd88df2cd3da94f4b2901c0f8927053629dfa8 100644
--- a/docs/dev/table_api.md
+++ b/docs/dev/table_api.md
@@ -1279,7 +1279,8 @@ A session window is defined by using the `Session` class as follows:
 Currently the following features are not supported yet:
 
 - Row-count windows on event-time
-- Windows on batch tables
+- Session windows on batch tables
+- Sliding windows on batch tables
 
 SQL
 ----
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala
index 8717431372bc16d1f7de2fb5a3b71a279ee1c945..957f4c5b050b849037c763f58d644408182f2c09 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala
@@ -802,9 +802,6 @@ class Table(
    * @return A windowed table.
    */
  def window(groupWindow: GroupWindow): GroupWindowedTable = {
-    if (tableEnv.isInstanceOf[BatchTableEnvironment]) {
-      throw new ValidationException(s"Windows on batch tables are currently not supported.")
-    }
     new GroupWindowedTable(this, Seq(), groupWindow)
   }
 }
@@ -872,9 +869,6 @@ class GroupedTable(
    * @return A windowed table.
    */
   def window(groupWindow: GroupWindow): GroupWindowedTable = {
-    if (table.tableEnv.isInstanceOf[BatchTableEnvironment]) {
-      throw new ValidationException(s"Windows on batch tables are currently not supported.")
-    }
     new GroupWindowedTable(table, groupKey, groupWindow)
   }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/fieldExpression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/fieldExpression.scala
index 299a850c44072cf1b9981a89e4326a2da867f07a..08859298c038add4f62c52e3fac30cdf4cc34d5b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/fieldExpression.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/fieldExpression.scala
@@ -36,7 +36,7 @@ abstract class Attribute extends LeafExpression with NamedExpression {
 
 case class UnresolvedFieldReference(name: String) extends Attribute {
 
-  override def toString = "\"" + name
+  override def toString = s"'$name"
 
   override private[flink] def withName(newName: String): Attribute = UnresolvedFieldReference(newName)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala
index b12e654731a124fe1d5335179029c3b8501e36e8..0bf149c9a8be5afb5734d892f191d2b8dc736ef2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala
@@ -55,7 +55,21 @@ abstract class EventTimeGroupWindow(
   }
 }
 
-abstract class ProcessingTimeGroupWindow(name: Option[Expression]) extends LogicalWindow(name)
+abstract class ProcessingTimeGroupWindow(name: Option[Expression]) extends LogicalWindow(name) {
+  override def validate(tableEnv: TableEnvironment): ValidationResult = {
+    val valid = super.validate(tableEnv)
+    if (valid.isFailure) {
+      return valid
+    }
+
+    tableEnv match {
+      case _: BatchTableEnvironment => ValidationFailure(
+        "Window on batch must declare a time attribute over which the query is evaluated.")
evaluated.") + case _ => + ValidationSuccess + } + } +} // ------------------------------------------------------------------------------------------------ // Tumbling group windows @@ -107,9 +121,11 @@ case class EventTimeTumblingGroupWindow( super.validate(tableEnv) .orElse(TumblingGroupWindow.validate(tableEnv, size)) .orElse(size match { - case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) => + case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) + if tableEnv.isInstanceOf[StreamTableEnvironment] => ValidationFailure( - "Event-time grouping windows on row intervals are currently not supported.") + "Event-time grouping windows on row intervals in a stream environment " + + "are currently not supported.") case _ => ValidationSuccess }) @@ -196,9 +212,11 @@ case class EventTimeSlidingGroupWindow( super.validate(tableEnv) .orElse(SlidingGroupWindow.validate(tableEnv, size, slide)) .orElse(size match { - case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) => + case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) + if tableEnv.isInstanceOf[StreamTableEnvironment] => ValidationFailure( - "Event-time grouping windows on row intervals are currently not supported.") + "Event-time grouping windows on row intervals in a stream environment " + + "are currently not supported.") case _ => ValidationSuccess }) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala new file mode 100644 index 0000000000000000000000000000000000000000..79497e6f4046d442d6296521ea86016bb3bacdaa --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
new file mode 100644
index 0000000000000000000000000000000000000000..79497e6f4046d442d6296521ea86016bb3bacdaa
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.plan.nodes.dataset
+
+import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rel.metadata.RelMetadataQuery
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.java.typeutils.{ResultTypeQueryable, RowTypeInfo}
+import org.apache.flink.table.api.BatchTableEnvironment
+import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.plan.logical._
+import org.apache.flink.table.plan.nodes.FlinkAggregate
+import org.apache.flink.table.runtime.aggregate.AggregateUtil.{CalcitePair, _}
+import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval
+import org.apache.flink.table.typeutils.TypeConverter
+import org.apache.flink.types.Row
+
+import scala.collection.JavaConversions._
+
+/**
+  * Flink RelNode which matches along with a LogicalWindowAggregate.
+  */
+class DataSetWindowAggregate(
+    window: LogicalWindow,
+    namedProperties: Seq[NamedWindowProperty],
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    inputNode: RelNode,
+    namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+    rowRelDataType: RelDataType,
+    inputType: RelDataType,
+    grouping: Array[Int])
+  extends SingleRel(cluster, traitSet, inputNode)
+  with FlinkAggregate
+  with DataSetRel {
+
+  override def deriveRowType() = rowRelDataType
+
+  override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+    new DataSetWindowAggregate(
+      window,
+      namedProperties,
+      cluster,
+      traitSet,
+      inputs.get(0),
+      namedAggregates,
+      getRowType,
+      inputType,
+      grouping)
+  }
+
+  override def toString: String = {
+    s"Aggregate(${
+      if (!grouping.isEmpty) {
+        s"groupBy: (${groupingToString(inputType, grouping)}), "
+      } else {
+        ""
+      }
+    }window: ($window), " +
+      s"select: (${
+        aggregationToString(
+          inputType,
+          grouping,
+          getRowType,
+          namedAggregates,
+          namedProperties)
+      }))"
+  }
+
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    super.explainTerms(pw)
+      .itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty)
+      .item("window", window)
+      .item(
+        "select", aggregationToString(
+          inputType,
+          grouping,
+          getRowType,
+          namedAggregates,
+          namedProperties))
+  }
+
+  override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+    val child = this.getInput
+    val rowCnt = metadata.getRowCount(child)
+    val rowSize = this.estimateRowSize(child.getRowType)
+    val aggCnt = this.namedAggregates.size
+    planner.getCostFactory.makeCost(rowCnt, rowCnt * aggCnt, rowCnt * rowSize)
+  }
+
+  override def translateToPlan(
+      tableEnv: BatchTableEnvironment,
+      expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {
+
+    val config = tableEnv.getConfig
+
+    val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(
+      tableEnv,
+      // tell the input operator that this operator currently only supports Rows as input
+      Some(TypeConverter.DEFAULT_ROW_TYPE))
+
+    // whether identifiers are matched case-sensitively
+    val caseSensitive = tableEnv.getFrameworkConfig.getParserConfig.caseSensitive()
+    val result = window match {
+      case EventTimeTumblingGroupWindow(_, _, size) =>
+        createEventTimeTumblingWindowDataSet(
+          inputDS,
+          isTimeInterval(size.resultType),
+          caseSensitive)
+
+      case EventTimeSessionGroupWindow(_, _, _) =>
+        throw new UnsupportedOperationException(
+          "Event-time session windows in a batch environment are currently not supported")
+      case EventTimeSlidingGroupWindow(_, _, _, _) =>
+        throw new UnsupportedOperationException(
+          "Event-time sliding windows in a batch environment are currently not supported")
+      case _: ProcessingTimeGroupWindow =>
+        throw new UnsupportedOperationException(
+          "Processing-time windows are not supported in a batch environment; " +
+            "windows in a batch environment must declare a time attribute over which " +
+            "the query is evaluated.")
+    }
+
+    // if the expected type is not a Row, inject a mapper to convert to the expected type
+    expectedType match {
+      case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] =>
+        val mapName = s"convert: (${getRowType.getFieldNames.toList.mkString(", ")})"
+        result.map(
+          getConversionMapper(
+            config = config,
+            nullableInput = false,
+            inputType = resultRowTypeInfo.asInstanceOf[TypeInformation[Any]],
+            expectedType = expectedType.get,
+            conversionOperatorName = "DataSetWindowAggregateConversion",
+            fieldNames = getRowType.getFieldNames
+          ))
+          .name(mapName)
+      case _ => result
+    }
+  }
+
+
+  private def createEventTimeTumblingWindowDataSet(
+      inputDS: DataSet[Any],
+      isTimeWindow: Boolean,
+      isParserCaseSensitive: Boolean)
+    : DataSet[Any] = {
+    val mapFunction = createDataSetWindowPrepareMapFunction(
+      window,
+      namedAggregates,
+      grouping,
+      inputType,
+      isParserCaseSensitive)
+    val groupReduceFunction = createDataSetWindowAggGroupReduceFunction(
+      window,
+      namedAggregates,
+      inputType,
+      getRowType,
+      grouping,
+      namedProperties)
+
+    val mappedInput = inputDS
+      .map(mapFunction)
+      .name(prepareOperatorName)
+
+    val mapReturnType = mapFunction.asInstanceOf[ResultTypeQueryable[Row]].getProducedType
+    if (isTimeWindow) {
+      // grouped time window aggregation
+      // group by grouping keys and rowtime field (the last field in the row)
+      val groupingKeys = grouping.indices ++ Seq(mapReturnType.getArity - 1)
+      mappedInput.asInstanceOf[DataSet[Row]]
+        .groupBy(groupingKeys: _*)
+        .reduceGroup(groupReduceFunction)
+        .returns(resultRowTypeInfo)
+        .name(aggregateOperatorName)
+        .asInstanceOf[DataSet[Any]]
+    } else {
+      // count window
+      val groupingKeys = grouping.indices.toArray
+      if (groupingKeys.length > 0) {
+        // grouped aggregation
+        mappedInput.asInstanceOf[DataSet[Row]]
+          .groupBy(groupingKeys: _*)
+          // sort on time field, it's the last element in the row
+          .sortGroup(mapReturnType.getArity - 1, Order.ASCENDING)
+          .reduceGroup(groupReduceFunction)
+          .returns(resultRowTypeInfo)
+          .name(aggregateOperatorName)
+          .asInstanceOf[DataSet[Any]]
+
+      } else {
+        // TODO: a count tumbling all-window on event-time should sort the entire data set
+        // on event-time before applying the windowing logic.
+        throw new UnsupportedOperationException(
+          "Count tumbling non-grouping windows on event-time are currently not supported.")
+      }
+    }
+  }
+
+  private def prepareOperatorName: String = {
+    val aggString = aggregationToString(
+      inputType,
+      grouping,
+      getRowType,
+      namedAggregates,
+      namedProperties)
+    s"prepare select: ($aggString)"
+  }
+
+  private def aggregateOperatorName: String = {
+    val aggString = aggregationToString(
+      inputType,
+      grouping,
+      getRowType,
+      namedAggregates,
+      namedProperties)
+    if (grouping.length > 0) {
+      s"groupBy: (${groupingToString(inputType, grouping)}), " +
+        s"window: ($window), select: ($aggString)"
+    } else {
+      s"window: ($window), select: ($aggString)"
+    }
+  }
+
+  private def resultRowTypeInfo: RowTypeInfo = {
+    // get the output types
+    val fieldTypes: Array[TypeInformation[_]] = getRowType.getFieldList
+      .map(field => FlinkTypeFactory.toTypeInfo(field.getType))
+      .toArray
+    new RowTypeInfo(fieldTypes: _*)
+  }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index 8c8b3048539ee9d0417fab9375af5a26f6c9672e..0ea018f90f22278cc89afd070bc9c1a3374f5d42 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -97,6 +97,7 @@ object FlinkRuleSets {
     CalcMergeRule.INSTANCE,
 
     // translate to Flink DataSet nodes
+    DataSetWindowAggregateRule.INSTANCE,
     DataSetAggregateRule.INSTANCE,
     DataSetAggregateWithNullValuesRule.INSTANCE,
     DataSetCalcRule.INSTANCE,
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetWindowAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetWindowAggregateRule.scala
new file mode 100644
index 0000000000000000000000000000000000000000..64f9f8b41076cd4137e0136e69f8736e8dbd9d0c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetWindowAggregateRule.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.plan.rules.dataSet
+
+import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.flink.table.api.TableException
+import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate
+import org.apache.flink.table.plan.nodes.dataset.{DataSetConvention, DataSetWindowAggregate}
+
+import scala.collection.JavaConversions._
+
+class DataSetWindowAggregateRule
+  extends ConverterRule(
+    classOf[LogicalWindowAggregate],
+    Convention.NONE,
+    DataSetConvention.INSTANCE,
+    "DataSetWindowAggregateRule") {
+
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val agg: LogicalWindowAggregate = call.rel(0).asInstanceOf[LogicalWindowAggregate]
+
+    // check if we have distinct aggregates
+    val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
+    if (distinctAggs) {
+      throw TableException("DISTINCT aggregates are currently not supported.")
+    }
+
+    // check if we have grouping sets
+    val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet
+    if (groupSets || agg.indicator) {
+      throw TableException("GROUPING SETS are currently not supported.")
+    }
+
+    !distinctAggs && !groupSets && !agg.indicator
+  }
+
+  override def convert(rel: RelNode): RelNode = {
+    val agg: LogicalWindowAggregate = rel.asInstanceOf[LogicalWindowAggregate]
+    val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
+    val convInput: RelNode = RelOptRule.convert(agg.getInput, DataSetConvention.INSTANCE)
+
+    new DataSetWindowAggregate(
+      agg.getWindow,
+      agg.getNamedProperties,
+      rel.getCluster,
+      traitSet,
+      convInput,
+      agg.getNamedAggCalls,
+      rel.getRowType,
+      agg.getInput.getRowType,
+      agg.getGroupSet.toArray)
+  }
+}
+
+object DataSetWindowAggregateRule {
+  val INSTANCE: RelOptRule = new DataSetWindowAggregateRule
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 282e6c0f59bcc89927a7fb986651282bed2a2523..1e4828863d2ec5a50e50c57131e73800229d1504 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -23,7 +23,7 @@ import org.apache.calcite.rel.`type`._
 import org.apache.calcite.rel.core.AggregateCall
 import org.apache.calcite.sql.{SqlAggFunction, SqlKind}
 import org.apache.calcite.sql.`type`.SqlTypeName._
-import org.apache.calcite.sql.`type`.{SqlTypeFactoryImpl, SqlTypeName}
+import org.apache.calcite.sql.`type`.SqlTypeName
 import org.apache.calcite.sql.fun._
 import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction}
 import org.apache.flink.api.common.typeinfo.TypeInformation
@@ -31,12 +31,13 @@ import org.apache.flink.api.java.tuple.Tuple
 import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
 import FlinkRelBuilder.NamedWindowProperty
-import org.apache.flink.table.expressions.{WindowEnd, WindowStart}
+import org.apache.flink.table.expressions._
 import org.apache.flink.table.plan.logical._
 import org.apache.flink.table.typeutils.TypeCheckUtils._
 import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
 import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
-import org.apache.flink.table.api.TableException
+import org.apache.flink.table.api.{TableException, Types}
+import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo}
 import org.apache.flink.types.Row
 
 import scala.collection.JavaConversions._
@@ -87,6 +88,151 @@ object AggregateUtil {
     mapFunction
   }
 
+
+  /**
+    * Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates.
+    * The output of the function contains the grouping keys, the timestamp, and the intermediate
+    * aggregate values of all aggregate functions. The timestamp field is aligned to the time
+    * window start and is used as a grouping key in case of a time window. In case of a count
+    * window on event-time, the timestamp is not aligned and is used for sorting.
+    *
+    * The output is stored in Row by the following format:
+    *
+    * {{{
+    *              avg(x) aggOffsetInRow = 2       count(z) aggOffsetInRow = 5
+    *                    |                             |
+    *                    v                             v
+    *   +---------+---------+--------+--------+--------+--------+--------+
+    *   |groupKey1|groupKey2|  sum1  | count1 |  sum2  | count2 | rowtime|
+    *   +---------+---------+--------+--------+--------+--------+--------+
+    *                                     ^                          ^
+    *                                     |                          |
+    *                   sum(y) aggOffsetInRow = 4      rowtime to group or sort
+    * }}}
+    *
+    * NOTE: this function is only used for time-based windows on batch tables.
+    */
+  def createDataSetWindowPrepareMapFunction(
+      window: LogicalWindow,
+      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+      groupings: Array[Int],
+      inputType: RelDataType,
+      isParserCaseSensitive: Boolean): MapFunction[Any, Row] = {
+
+    val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
+      namedAggregates.map(_.getKey),
+      inputType,
+      groupings.length)
+
+    val mapReturnType: RowTypeInfo =
+      createAggregateBufferDataType(groupings, aggregates, inputType, Some(Types.LONG))
+
+    val (timeFieldPos, tumbleTimeWindowSize) = window match {
+      case EventTimeTumblingGroupWindow(_, time, size) =>
+        val timeFieldPos = getTimeFieldPosition(time, inputType, isParserCaseSensitive)
+        size match {
+          case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
+            (timeFieldPos, Some(value))
+          case _ => (timeFieldPos, None)
+        }
+      case EventTimeSessionGroupWindow(_, time, _) =>
+        (getTimeFieldPosition(time, inputType, isParserCaseSensitive), None)
+      case _ =>
+        throw new UnsupportedOperationException(s"$window is currently not supported on batch")
+    }
+
+    new DataSetWindowAggregateMapFunction(
+      aggregates,
+      aggFieldIndexes,
+      groupings,
+      timeFieldPos,
+      tumbleTimeWindowSize,
+      mapReturnType).asInstanceOf[MapFunction[Any, Row]]
+  }
+
+  /**
+    * Create a [[org.apache.flink.api.common.functions.GroupReduceFunction]] to compute window
+    * aggregates on batch tables. If all aggregates support partial aggregation and the window
+    * is a time window, the [[org.apache.flink.api.common.functions.GroupReduceFunction]]
+    * implements [[org.apache.flink.api.common.functions.CombineFunction]] as well.
+    *
+    * NOTE: this function is only used for windows on batch tables.
+    */
+  def createDataSetWindowAggGroupReduceFunction(
+      window: LogicalWindow,
+      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+      inputType: RelDataType,
+      outputType: RelDataType,
+      groupings: Array[Int],
+      properties: Seq[NamedWindowProperty]): RichGroupReduceFunction[Row, Row] = {
+
+    val aggregates = transformToAggregateFunctions(
+      namedAggregates.map(_.getKey),
+      inputType,
+      groupings.length)._2
+
+    // the additional field is used to store the time attribute
+    val intermediateRowArity = groupings.length +
+      aggregates.map(_.intermediateDataType.length).sum + 1
+
+    // the mapping relation between field index of intermediate aggregate Row and output Row.
+    val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings)
+
+    // the mapping relation between aggregate function index in list and its corresponding
+    // field index in output Row.
+    val aggOffsetMapping = getAggregateMapping(namedAggregates, outputType)
+
+    if (groupingOffsetMapping.length != groupings.length ||
+      aggOffsetMapping.length != namedAggregates.length) {
+      throw new TableException(
+        "Could not find output field in input data type " +
+          "or aggregate functions.")
+    }
+
+    window match {
+      case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) =>
+        // tumbling time window
+        val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
+        if (aggregates.forall(_.supportPartial)) {
+          // for incremental aggregations
+          new DataSetTumbleTimeWindowAggReduceCombineFunction(
+            intermediateRowArity - 1,
+            asLong(size),
+            startPos,
+            endPos,
+            aggregates,
+            groupingOffsetMapping,
+            aggOffsetMapping,
+            intermediateRowArity,
+            outputType.getFieldCount)
+        }
+        else {
+          // for non-incremental aggregations
+          new DataSetTumbleTimeWindowAggReduceGroupFunction(
+            intermediateRowArity - 1,
+            asLong(size),
+            startPos,
+            endPos,
+            aggregates,
+            groupingOffsetMapping,
+            aggOffsetMapping,
+            intermediateRowArity,
+            outputType.getFieldCount)
+        }
+      case EventTimeTumblingGroupWindow(_, _, size) =>
+        // tumbling count window
+        new DataSetTumbleCountWindowAggReduceGroupFunction(
+          asLong(size),
+          aggregates,
+          groupingOffsetMapping,
+          aggOffsetMapping,
+          intermediateRowArity,
+          outputType.getFieldCount)
+      case _ =>
+        throw new UnsupportedOperationException(s"$window is currently not supported on batch")
+    }
+  }
+
   /**
    * Create a [[org.apache.flink.api.common.functions.GroupReduceFunction]] to compute aggregates.
    * If all aggregates support partial aggregation, the
@@ -360,7 +506,7 @@ object AggregateUtil {
     }
   }
 
-  private def computeWindowStartEndPropertyPos(
+  private[flink] def computeWindowStartEndPropertyPos(
       properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int]) = {
 
     val propPos = properties.foldRight((None: Option[Int], None: Option[Int], 0)) {
@@ -517,22 +663,20 @@ object AggregateUtil {
   private def createAggregateBufferDataType(
     groupings: Array[Int],
     aggregates: Array[Aggregate[_]],
-    inputType: RelDataType): RowTypeInfo = {
+    inputType: RelDataType,
+    windowKeyType: Option[TypeInformation[_]] = None): RowTypeInfo = {
 
     // get the field data types of group keys.
     val groupingTypes: Seq[TypeInformation[_]] = groupings
       .map(inputType.getFieldList.get(_).getType)
       .map(FlinkTypeFactory.toTypeInfo)
 
-    val aggPartialNameSuffix = "agg_buffer_"
-    val factory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
-
     // get all field data types of all intermediate aggregates
     val aggTypes: Seq[TypeInformation[_]] = aggregates.flatMap(_.intermediateDataType)
 
-    // concat group key types and aggregation types
-    val allFieldTypes = groupingTypes ++: aggTypes
-    val partialType = new RowTypeInfo(allFieldTypes: _*)
+    // concat group key types, aggregation types, and window key types (may be empty)
+    val allFieldTypes = groupingTypes ++: aggTypes ++: windowKeyType
+    val partialType = new RowTypeInfo(allFieldTypes.toSeq: _*)
     partialType
   }
@@ -591,5 +735,40 @@ object AggregateUtil {
     groupingOffsetMapping.toArray
   }
 
+
+  private def getTimeFieldPosition(
+      timeField: Expression,
+      inputType: RelDataType,
+      isParserCaseSensitive: Boolean): Int = {
+
+    timeField match {
+      case ResolvedFieldReference(name, _) =>
+        // get the RelDataType referenced by the time-field
+        val relDataType = inputType.getFieldList.filter { r =>
+          if (isParserCaseSensitive) {
+            name.equals(r.getName)
+          } else {
+            name.equalsIgnoreCase(r.getName)
+          }
+        }
+        // should only match one
+        if (relDataType.length == 1) {
+          relDataType.head.getIndex
+        } else {
+          throw TableException(
+            s"The time attribute could not be resolved to exactly one field: $relDataType")
+        }
+      case e => throw TableException(
+        "The time attribute of a window in a batch environment should be a " +
+          s"ResolvedFieldReference, but is $e")
+    }
+  }
+
+  private def asLong(expr: Expression): Long = expr match {
+    case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) => value
+    case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value
+    case _ => throw new IllegalArgumentException(
+      s"Expected a time interval or row interval literal, but got $expr")
+  }
+
 }
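For orientation, the buffer row built by createAggregateBufferDataType with a window key appended can be pictured concretely. The following is a hypothetical layout sketch (not part of the patch) for groupBy('string) with an 'int.count aggregate and windowKeyType = Some(Types.LONG); the single Long intermediate field for COUNT and the trailing rowtime key follow the scaladoc diagram above:

    import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
    import org.apache.flink.api.java.typeutils.RowTypeInfo

    // group key ++ intermediate aggregate fields ++ aligned rowtime window key
    val bufferType = new RowTypeInfo(
      STRING_TYPE_INFO, // group key 'string
      LONG_TYPE_INFO,   // count intermediate: running count
      LONG_TYPE_INFO)   // rowtime key (windowKeyType = Some(Types.LONG))
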
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
new file mode 100644
index 0000000000000000000000000000000000000000..40dad171462f674566e625d126bb2bcb258db6f4
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.lang.Iterable
+
+import org.apache.flink.api.common.functions.RichGroupReduceFunction
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+
+/**
+  * It wraps the aggregate logic inside of
+  * [[org.apache.flink.api.java.operators.GroupReduceOperator]].
+  * It is only used for tumbling count-window on batch.
+  *
+  * @param windowSize Tumble count window size
+  * @param aggregates The aggregate functions.
+  * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+  *                         and output Row.
+  * @param aggregateMapping The index mapping between aggregate function list and aggregated value
+  *                         index in output Row.
+  * @param intermediateRowArity The intermediate row field count
+  * @param finalRowArity The output row field count
+  */
+class DataSetTumbleCountWindowAggReduceGroupFunction(
+    private val windowSize: Long,
+    private val aggregates: Array[Aggregate[_ <: Any]],
+    private val groupKeysMapping: Array[(Int, Int)],
+    private val aggregateMapping: Array[(Int, Int)],
+    private val intermediateRowArity: Int,
+    private val finalRowArity: Int)
+  extends RichGroupReduceFunction[Row, Row] {
+
+  private var aggregateBuffer: Row = _
+  private var output: Row = _
+
+  override def open(config: Configuration) {
+    Preconditions.checkNotNull(aggregates)
+    Preconditions.checkNotNull(groupKeysMapping)
+    aggregateBuffer = new Row(intermediateRowArity)
+    output = new Row(finalRowArity)
+  }
+
+  override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
+
+    var count: Long = 0
+
+    val iterator = records.iterator()
+
+    while (iterator.hasNext) {
+      val record = iterator.next()
+      if (count == 0) {
+        // initiate intermediate aggregate value.
+        aggregates.foreach(_.initiate(aggregateBuffer))
+      }
+      // merge intermediate aggregate value to buffer.
+      aggregates.foreach(_.merge(record, aggregateBuffer))
+
+      count += 1
+      if (windowSize == count) {
+        // set group keys value to final output.
+        groupKeysMapping.foreach {
+          case (after, previous) =>
+            output.setField(after, record.getField(previous))
+        }
+        // evaluate final aggregate value and set to output.
+        aggregateMapping.foreach {
+          case (after, previous) =>
+            output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
+        }
+        // emit the output
+        out.collect(output)
+        count = 0
+      }
+    }
+  }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
new file mode 100644
index 0000000000000000000000000000000000000000..a72c9cafb2de14b415b3f9e204666f4cc769dc95
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.lang.Iterable
+
+import org.apache.flink.api.common.functions.CombineFunction
+import org.apache.flink.types.Row
+
+/**
+  * It wraps the aggregate logic inside of
+  * [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
+  * [[org.apache.flink.api.java.operators.GroupCombineOperator]].
+  * It is used for tumbling time-window on batch.
+  *
+  * @param rowtimePos The rowtime field index in input row
+  * @param windowSize Tumbling time window size
+  * @param windowStartPos The relative window-start field position to the last field of output row
+  * @param windowEndPos The relative window-end field position to the last field of output row
+  * @param aggregates The aggregate functions.
+  * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+  *                         and output Row.
+  * @param aggregateMapping The index mapping between aggregate function list and aggregated value
+  *                         index in output Row.
+  * @param intermediateRowArity The intermediate row field count
+  * @param finalRowArity The output row field count
+  */
+class DataSetTumbleTimeWindowAggReduceCombineFunction(
+    rowtimePos: Int,
+    windowSize: Long,
+    windowStartPos: Option[Int],
+    windowEndPos: Option[Int],
+    aggregates: Array[Aggregate[_ <: Any]],
+    groupKeysMapping: Array[(Int, Int)],
+    aggregateMapping: Array[(Int, Int)],
+    intermediateRowArity: Int,
+    finalRowArity: Int)
+  extends DataSetTumbleTimeWindowAggReduceGroupFunction(
+    rowtimePos,
+    windowSize,
+    windowStartPos,
+    windowEndPos,
+    aggregates,
+    groupKeysMapping,
+    aggregateMapping,
+    intermediateRowArity,
+    finalRowArity)
+  with CombineFunction[Row, Row] {
+
+  /**
+    * For sub-grouped intermediate aggregate Rows, merge all of them into the aggregate buffer.
+    *
+    * @param records Sub-grouped intermediate aggregate Rows iterator.
+    * @return Combined intermediate aggregate Row.
+    */
+  override def combine(records: Iterable[Row]): Row = {
+
+    // initiate intermediate aggregate value.
+    aggregates.foreach(_.initiate(aggregateBuffer))
+
+    // merge intermediate aggregate value to buffer.
+    var last: Row = null
+
+    val iterator = records.iterator()
+    while (iterator.hasNext) {
+      val record = iterator.next()
+      aggregates.foreach(_.merge(record, aggregateBuffer))
+      last = record
+    }
+
+    // set group keys to aggregateBuffer.
+    for (i <- groupKeysMapping.indices) {
+      aggregateBuffer.setField(i, last.getField(i))
+    }
+
+    // set the rowtime attribute
+    aggregateBuffer.setField(rowtimePos, last.getField(rowtimePos))
+
+    aggregateBuffer
+  }
+
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ae123d73d7f3cfa1bbf3c5fece54aa7d0f654fda
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.lang.Iterable
+
+import org.apache.flink.api.common.functions.RichGroupReduceFunction
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+
+/**
+  * It wraps the aggregate logic inside of
+  * [[org.apache.flink.api.java.operators.GroupReduceOperator]]. It is used for tumbling
+  * time-window on batch.
+  *
+  * @param rowtimePos The rowtime field index in input row
+  * @param windowSize Tumbling time window size
+  * @param windowStartPos The relative window-start field position to the last field of output row
+  * @param windowEndPos The relative window-end field position to the last field of output row
+  * @param aggregates The aggregate functions.
+  * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+  *                         and output Row.
+  * @param aggregateMapping The index mapping between aggregate function list and aggregated value
+  *                         index in output Row.
+  * @param intermediateRowArity The intermediate row field count
+  * @param finalRowArity The output row field count
+  */
+class DataSetTumbleTimeWindowAggReduceGroupFunction(
+    rowtimePos: Int,
+    windowSize: Long,
+    windowStartPos: Option[Int],
+    windowEndPos: Option[Int],
+    aggregates: Array[Aggregate[_ <: Any]],
+    groupKeysMapping: Array[(Int, Int)],
+    aggregateMapping: Array[(Int, Int)],
+    intermediateRowArity: Int,
+    finalRowArity: Int)
+  extends RichGroupReduceFunction[Row, Row] {
+
+  private var collector: TimeWindowPropertyCollector = _
+  protected var aggregateBuffer: Row = _
+  private var output: Row = _
+
+  override def open(config: Configuration) {
+    Preconditions.checkNotNull(aggregates)
+    Preconditions.checkNotNull(groupKeysMapping)
+    aggregateBuffer = new Row(intermediateRowArity)
+    output = new Row(finalRowArity)
+    collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
+  }
+
+  override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
+
+    // initiate intermediate aggregate value.
+    aggregates.foreach(_.initiate(aggregateBuffer))
+
+    // merge intermediate aggregate value to buffer.
+    var last: Row = null
+
+    val iterator = records.iterator()
+    while (iterator.hasNext) {
+      val record = iterator.next()
+      aggregates.foreach(_.merge(record, aggregateBuffer))
+      last = record
+    }
+
+    // set group keys value to final output.
+    groupKeysMapping.foreach {
+      case (after, previous) =>
+        output.setField(after, last.getField(previous))
+    }
+
+    // evaluate final aggregate value and set to output.
+    aggregateMapping.foreach {
+      case (after, previous) =>
+        output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
+    }
+
+    // get window start timestamp
+    val startTs: Long = last.getField(rowtimePos).asInstanceOf[Long]
+
+    // set collector and window
+    collector.wrappedCollector = out
+    collector.timeWindow = new TimeWindow(startTs, startTs + windowSize)
+
+    collector.collect(output)
+  }
+
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c9fb51424b3124d4a39864cc7151cbd1827d2f1a
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggregateMapFunction.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.sql.Timestamp
+
+import org.apache.flink.api.common.functions.RichMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow
+import org.apache.flink.types.Row
+import org.apache.flink.util.Preconditions
+
+
+/**
+  * This map function only works for windows on batch tables. The difference between this function
+  * and [[org.apache.flink.table.runtime.aggregate.AggregateMapFunction]] is that this function
+  * appends an (aligned) rowtime field to the end of the output row.
+  */
+class DataSetWindowAggregateMapFunction(
+    private val aggregates: Array[Aggregate[_]],
+    private val aggFields: Array[Int],
+    private val groupingKeys: Array[Int],
+    private val timeFieldPos: Int, // time field position in input row
+    private val tumbleTimeWindowSize: Option[Long],
+    @transient private val returnType: TypeInformation[Row])
+  extends RichMapFunction[Row, Row] with ResultTypeQueryable[Row] {
+
+  private var output: Row = _
+  // rowtime index in the buffer output row
+  private var rowtimeIndex: Int = _
+
+  override def open(config: Configuration) {
+    Preconditions.checkNotNull(aggregates)
+    Preconditions.checkNotNull(aggFields)
+    Preconditions.checkArgument(aggregates.length == aggFields.length)
+    // add one more arity to store rowtime
+    val partialRowLength = groupingKeys.length +
+      aggregates.map(_.intermediateDataType.length).sum + 1
+    // set rowtime to the last field of the output row
+    rowtimeIndex = partialRowLength - 1
+    output = new Row(partialRowLength)
+  }
+
+  override def map(input: Row): Row = {
+    for (i <- aggregates.indices) {
+      val fieldValue = input.getField(aggFields(i))
+      aggregates(i).prepare(fieldValue, output)
+    }
+    for (i <- groupingKeys.indices) {
+      output.setField(i, input.getField(groupingKeys(i)))
+    }
+
+    val timeField = input.getField(timeFieldPos)
+    val rowtime = getTimestamp(timeField)
+    if (tumbleTimeWindowSize.isDefined) {
+      // in case of tumble time window, align rowtime to window start to represent the window
+      output.setField(
+        rowtimeIndex,
+        TimeWindow.getWindowStartWithOffset(rowtime, 0L, tumbleTimeWindowSize.get))
+    } else {
+      // otherwise, set rowtime for future use
+      output.setField(rowtimeIndex, rowtime)
+    }
+
+    output
+  }
+
+  private def getTimestamp(timeField: Any): Long = {
+    timeField match {
+      case b: Byte => b.toLong
+      case t: Character => t.toLong
+      case s: Short => s.toLong
+      case i: Int => i.toLong
+      case l: Long => l
+      case f: Float => f.toLong
+      case d: Double => d.toLong
+      case s: String => s.toLong
+      case t: Timestamp => t.getTime
+      case _ =>
+        throw new RuntimeException(
+          s"Window time field of type ${timeField.getClass} is currently not supported.")
+    }
+  }
+
+  override def getProducedType: TypeInformation[Row] = {
+    returnType
+  }
+}
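Since the map function above leans on TimeWindow.getWindowStartWithOffset for the alignment step, its effect can be checked in isolation. A small self-contained sketch, with window size and timestamps taken from the ITCase data added further down:

    import org.apache.flink.streaming.api.windowing.windows.TimeWindow

    object RowtimeAlignmentSketch {
      def main(args: Array[String]): Unit = {
        val windowSize = 5L // 5 ms tumbling windows
        for (rowtime <- Seq(1L, 4L, 6L, 16L)) {
          val start = TimeWindow.getWindowStartWithOffset(rowtime, 0L, windowSize)
          // 1 and 4 fall into [0, 5), 6 into [5, 10), 16 into [15, 20)
          println(s"rowtime $rowtime -> window [$start, ${start + windowSize})")
        }
      }
    }
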
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala
index 94c2a5cb1471f318bf25ec6b6d773ac4970b105e..62c70a29a486da7cd7aec958ec424098d1f4d06e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala
@@ -341,5 +341,4 @@ class AggregationsITCase(
     val results = t.toDataSet[Row].collect()
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
-
 }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
index cc691d27871e7a6fb537286d3d6d90e58993bd37..708e7f176b33454bbfd290c38b94da46beef879f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
@@ -303,34 +303,6 @@ class FieldProjectionTest extends TableTestBase {
     util.verifyTable(resultTable, expected)
   }
-
-  @Test(expected = classOf[ValidationException])
-  def testSelectFromBatchWindow1(): Unit = {
-    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
-
-    // time field is selected
-    val resultTable = sourceTable
-      .window(Tumble over 5.millis on 'a as 'w)
-      .select('a.sum, 'c.count)
-
-    val expected = "TODO"
-
-    util.verifyTable(resultTable, expected)
-  }
-
-  @Test(expected = classOf[ValidationException])
-  def testSelectFromBatchWindow2(): Unit = {
-    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
-
-    // time field is not selected
-    val resultTable = sourceTable
-      .window(Tumble over 5.millis on 'a as 'w)
-      .select('c.count)
-
-    val expected = "TODO"
-
-    util.verifyTable(resultTable, expected)
-  }
 }
 
 object FieldProjectionTest {
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/GroupWindowTest.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e3c36490ba829b888cffe5176c472941dea6e8be
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/GroupWindowTest.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.scala.batch.table
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.plan.logical._
+import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.table.utils.TableTestUtil._
+import org.junit.Test
+
+class GroupWindowTest extends TableTestBase {
+
+  //===============================================================================================
+  // Tumbling Windows
+  //===============================================================================================
+
+  @Test(expected = classOf[ValidationException])
+  def testProcessingTimeTumblingGroupWindowOverTime(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    table
+      .groupBy('string)
+      .window(Tumble over 50.milli) // require a time attribute
+      .select('string, 'int.count)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testProcessingTimeTumblingGroupWindowOverCount(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    table
+      .groupBy('string)
+      .window(Tumble over 2.rows) // require a time attribute
+      .select('string, 'int.count)
+  }
+
+  @Test
+  def testEventTimeTumblingGroupWindowOverCount(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    val windowedTable = table
+      .groupBy('string)
+      .window(Tumble over 2.rows on 'long)
+      .select('string, 'int.count)
+
+    val expected = unaryNode(
+      "DataSetWindowAggregate",
+      batchTableNode(0),
+      term("groupBy", "string"),
+      term("window", EventTimeTumblingGroupWindow(None, 'long, 2.rows)),
+      term("select", "string", "COUNT(int) AS TMP_0")
+    )
+
+    util.verifyTable(windowedTable, expected)
+  }
+
+  @Test
+  def testEventTimeTumblingGroupWindowOverTime(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    val windowedTable = table
+      .groupBy('string)
+      .window(Tumble over 5.milli on 'long)
+      .select('string, 'int.count)
+
+    val expected = unaryNode(
+      "DataSetWindowAggregate",
+      batchTableNode(0),
+      term("groupBy", "string"),
+      term("window", EventTimeTumblingGroupWindow(None, 'long, 5.milli)),
+      term("select", "string", "COUNT(int) AS TMP_0")
+    )
+
+    util.verifyTable(windowedTable, expected)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testAllProcessingTimeTumblingGroupWindowOverTime(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    table
+      .window(Tumble over 50.milli) // require a time attribute
+      .select('string, 'int.count)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testAllProcessingTimeTumblingGroupWindowOverCount(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    table
+      .window(Tumble over 2.rows) // require a time attribute
+      .select('int.count)
+  }
+
+  @Test
+  def testAllEventTimeTumblingGroupWindowOverTime(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    val windowedTable = table
+      .window(Tumble over 5.milli on 'long)
+      .select('int.count)
+
+    val expected = unaryNode(
+      "DataSetWindowAggregate",
+      unaryNode(
+        "DataSetCalc",
+        batchTableNode(0),
+        term("select", "int", "long")
+      ),
+      term("window", EventTimeTumblingGroupWindow(None, 'long, 5.milli)),
+      term("select", "COUNT(int) AS TMP_0")
+    )
+
+    util.verifyTable(windowedTable, expected)
+  }
+
+  @Test
+  def testAllEventTimeTumblingGroupWindowOverCount(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    val windowedTable = table
+      .window(Tumble over 2.rows on 'long)
+      .select('int.count)
+
+    val expected = unaryNode(
+      "DataSetWindowAggregate",
+      unaryNode(
+        "DataSetCalc",
+        batchTableNode(0),
+        term("select", "int", "long")
+      ),
+      term("window", EventTimeTumblingGroupWindow(None, 'long, 2.rows)),
+      term("select", "COUNT(int) AS TMP_0")
+    )
+
+    util.verifyTable(windowedTable, expected)
+  }
+
+  //===============================================================================================
+  // Sliding Windows
+  //===============================================================================================
+
+  @Test(expected = classOf[ValidationException])
+  def testProcessingTimeSlidingGroupWindowOverTime(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    table
+      .groupBy('string)
+      .window(Slide over 50.milli every 50.milli) // require a time attribute
+      .select('string, 'int.count)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testProcessingTimeSlidingGroupWindowOverCount(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    table
+      .groupBy('string)
+      .window(Slide over 10.rows every 5.rows) // require a time attribute
+      .select('string, 'int.count)
+  }
+
+  @Test
+  def testEventTimeSlidingGroupWindowOverTime(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    val windowedTable = table
+      .groupBy('string)
+      .window(Slide over 8.milli every 10.milli on 'long)
+      .select('string, 'int.count)
+
+    val expected = unaryNode(
+      "DataSetWindowAggregate",
+      batchTableNode(0),
+      term("groupBy", "string"),
+      term("window", EventTimeSlidingGroupWindow(None, 'long, 8.milli, 10.milli)),
+      term("select", "string", "COUNT(int) AS TMP_0")
+    )
+
+    util.verifyTable(windowedTable, expected)
+  }
+
+  @Test
+  def testEventTimeSlidingGroupWindowOverCount(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    val windowedTable = table
+      .groupBy('string)
+      .window(Slide over 2.rows every 1.rows on 'long)
+      .select('string, 'int.count)
+
+    val expected = unaryNode(
+      "DataSetWindowAggregate",
+      batchTableNode(0),
+      term("groupBy", "string"),
+      term("window", EventTimeSlidingGroupWindow(None, 'long, 2.rows, 1.rows)),
+      term("select", "string", "COUNT(int) AS TMP_0")
+    )
+
+    util.verifyTable(windowedTable, expected)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testAllProcessingTimeSlidingGroupWindowOverCount(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    table
+      .window(Slide over 2.rows every 1.rows) // require a time attribute
+      .select('int.count)
+  }
+
+  @Test
+  def testAllEventTimeSlidingGroupWindowOverTime(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Long, Int, String)]('long, 'int, 'string)
+
+    val windowedTable = table
+      .window(Slide over 8.milli every 10.milli on 'long)
+      .select('int.count)
+
+    val expected = unaryNode(
+      "DataSetWindowAggregate",
+      unaryNode(
+        "DataSetCalc",
+        batchTableNode(0),
term("select", "int", "long") + ), + term("window", EventTimeSlidingGroupWindow(None, 'long, 8.milli, 10.milli)), + term("select", "COUNT(int) AS TMP_0") + ) + + util.verifyTable(windowedTable, expected) + } + + @Test + def testAllEventTimeSlidingGroupWindowOverCount(): Unit = { + val util = batchTestUtil() + val table = util.addTable[(Long, Int, String)]('long, 'int, 'string) + + val windowedTable = table + .window(Slide over 2.rows every 1.rows on 'long) + .select('int.count) + + val expected = unaryNode( + "DataSetWindowAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "int", "long") + ), + term("window", EventTimeSlidingGroupWindow(None, 'long, 2.rows, 1.rows)), + term("select", "COUNT(int) AS TMP_0") + ) + + util.verifyTable(windowedTable, expected) + } + + //=============================================================================================== + // Session Windows + //=============================================================================================== + + @Test + def testEventTimeSessionGroupWindowOverTime(): Unit = { + val util = batchTestUtil() + val table = util.addTable[(Long, Int, String)]('long, 'int, 'string) + + val windowedTable = table + .groupBy('string) + .window(Session withGap 7.milli on 'long) + .select('string, 'int.count) + + val expected = unaryNode( + "DataSetWindowAggregate", + batchTableNode(0), + term("groupBy", "string"), + term("window", EventTimeSessionGroupWindow(None, 'long, 7.milli)), + term("select", "string", "COUNT(int) AS TMP_0") + ) + + util.verifyTable(windowedTable, expected) + } + + @Test + def testAllEventTimeSessionGroupWindowOverTime(): Unit = { + val util = batchTestUtil() + val table = util.addTable[(Long, Int, String)]('long, 'int, 'string) + + val windowedTable = table + .window(Session withGap 7.milli on 'long) + .select('int.count) + + val expected = unaryNode( + "DataSetWindowAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "int", "long") + ), + term("window", EventTimeSessionGroupWindow(None, 'long, 7.milli)), + term("select", "COUNT(int) AS TMP_0") + ) + + util.verifyTable(windowedTable, expected) + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowTest.scala index ee24cf779e539f27f04169a708a70bce7c786aaa..cbd814a81c92d7d03d768dcb6732b5e96ff4e3da 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowTest.scala @@ -30,17 +30,6 @@ import org.junit.{Ignore, Test} class GroupWindowTest extends TableTestBase { - // batch windows are not supported yet - @Test(expected = classOf[ValidationException]) - def testInvalidBatchWindow(): Unit = { - val util = batchTestUtil() - val table = util.addTable[(Long, Int, String)]('long, 'int, 'string) - - table - .groupBy('string) - .window(Session withGap 100.milli as 'string) - } - @Test(expected = classOf[ValidationException]) def testInvalidWindowProperty(): Unit = { val util = streamTestUtil() diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetWindowAggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetWindowAggregateITCase.scala new file mode 100644 index 
index 0000000000000000000000000000000000000000..fbdbec416a857464c0e1b84d032c3e37bc91e31a
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetWindowAggregateITCase.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.dataset
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.apache.flink.test.util.TestBaseUtils
+import org.apache.flink.types.Row
+import org.junit._
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+
+@RunWith(classOf[Parameterized])
+class DataSetWindowAggregateITCase(
+    mode: TestExecutionMode,
+    configMode: TableConfigMode)
+  extends TableProgramsTestBase(mode, configMode) {
+
+  val data = List(
+    (1L, 1, "Hi"),
+    (2L, 2, "Hallo"),
+    (3L, 2, "Hello"),
+    (6L, 3, "Hello"),
+    (4L, 5, "Hello"),
+    (16L, 4, "Hello world"),
+    (8L, 3, "Hello world"))
+
+  @Test(expected = classOf[UnsupportedOperationException])
+  def testAllEventTimeTumblingWindowOverCount(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+
+    // count tumbling non-grouping windows on event-time are currently not supported
+    table
+      .window(Tumble over 2.rows on 'long)
+      .select('int.count)
+      .toDataSet[Row]
+  }
+
+  @Test
+  def testEventTimeTumblingGroupWindowOverCount(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+
+    val windowedTable = table
+      .groupBy('string)
+      .window(Tumble over 2.rows on 'long)
+      .select('string, 'int.sum)
+
+    val expected = "Hello,7\n" + "Hello world,7\n"
+    val results = windowedTable.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testEventTimeTumblingGroupWindowOverTime(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+
+    val windowedTable = table
+      .groupBy('string)
+      .window(Tumble over 5.milli on 'long as 'w)
+      .select('string, 'int.sum, 'w.start, 'w.end)
+
+    val expected = "Hello world,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01\n" +
+      "Hello world,4,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02\n" +
+      "Hello,7,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005\n" +
+      "Hello,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01\n" +
+      "Hallo,2,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005\n" +
+      "Hi,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005\n"
+
+    val results = windowedTable.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testAllEventTimeTumblingWindowOverTime(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)
+
+    val windowedTable = table
+      .window(Tumble over 5.milli on 'long as 'w)
+      .select('int.sum, 'w.start, 'w.end)
+
+    val expected = "10,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005\n" +
+      "6,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01\n" +
+      "4,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02\n"
+
+    val results = windowedTable.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+}
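Taken together, the change enables event-time group-windows on batch tables end to end. A minimal runnable sketch modeled on DataSetWindowAggregateITCase, using a standalone main instead of the test harness (the surrounding cluster setup is assumed):

    import org.apache.flink.api.scala._
    import org.apache.flink.table.api.TableEnvironment
    import org.apache.flink.table.api.scala._
    import org.apache.flink.types.Row

    object BatchGroupWindowExample {
      def main(args: Array[String]): Unit = {
        val env = ExecutionEnvironment.getExecutionEnvironment
        val tEnv = TableEnvironment.getTableEnvironment(env)

        val table = env
          .fromElements((1L, 1, "Hi"), (2L, 2, "Hallo"), (3L, 2, "Hello"), (6L, 3, "Hello"))
          .toTable(tEnv, 'long, 'int, 'string)

        // tumbling event-time window of 5 ms over the batch table
        val result = table
          .groupBy('string)
          .window(Tumble over 5.milli on 'long as 'w)
          .select('string, 'int.sum, 'w.start, 'w.end)

        result.toDataSet[Row].collect().foreach(println)
      }
    }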