Commit ef8cdfe5 authored by Aleksandr Chermenin, committed by twalthr

[FLINK-5303] [table] Add CUBE/ROLLUP/GROUPING SETS operator in SQL

This closes #2976.
Parent 68228ffd
......@@ -1334,7 +1334,7 @@ Among others, the following SQL features are not supported, yet:
- Interval arithmetic is currently limited
- Distinct aggregates (e.g., `COUNT(DISTINCT name)`)
- Non-equi joins and Cartesian products
- Grouping sets
- Efficient grouping sets
*Note: Tables are joined in the order in which they are specified in the `FROM` clause. In some cases the table order must be manually tweaked to resolve Cartesian products.*
......@@ -1442,7 +1442,9 @@ groupItem:
expression
| '(' ')'
| '(' expression [, expression ]* ')'
| CUBE '(' expression [, expression ]* ')'
| ROLLUP '(' expression [, expression ]* ')'
| GROUPING SETS '(' groupItem [, groupItem ]* ')'
```
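For illustration, a grouping-sets query can be submitted through the batch Table API's SQL interface. The following minimal Scala sketch (the input data and table name are made up; the query mirrors the tests added by this commit) computes one aggregate per grouping set and exposes `GROUP_ID()`:

```scala
import org.apache.flink.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.types.Row

object GroupingSetsExample {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // small hypothetical input with fields a, b, c
    val ds = env.fromElements((1, 1L, "Hi"), (2, 2L, "Hello"), (3, 2L, "Hello world"))
    tEnv.registerDataSet("MyTable", ds, 'a, 'b, 'c)

    // one aggregation per grouping set: grouped by b, and grouped by c
    val result = tEnv.sql(
      "SELECT b, c, AVG(a) AS a, GROUP_ID() AS g " +
        "FROM MyTable GROUP BY GROUPING SETS (b, c)")

    result.toDataSet[Row].print()
  }
}
```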
For a better definition of SQL queries within a Java String, Flink SQL uses a lexical policy similar to Java:
......@@ -3759,6 +3761,50 @@ MIN(value)
</tbody>
</table>
<table class="table table-bordered">
<thead>
<tr>
<th class="text-left" style="width: 40%">Grouping functions</th>
<th class="text-center">Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>
{% highlight text %}
GROUP_ID()
{% endhighlight %}
</td>
<td>
<p>Returns an integer that uniquely identifies the combination of grouping keys.</p>
</td>
</tr>
<tr>
<td>
{% highlight text %}
GROUPING(expression)
{% endhighlight %}
</td>
<td>
<p>Returns 1 if <i>expression</i> is rolled up in the current row’s grouping set, 0 otherwise.</p>
</td>
</tr>
<tr>
<td>
{% highlight text %}
GROUPING_ID(expression [, expression]* )
{% endhighlight %}
</td>
<td>
<p>Returns a bit vector of the given grouping expressions.</p>
</td>
</tr>
</tbody>
</table>
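As a rough sketch of how these functions relate, assuming the expansion used by the planner tests added in this commit, `GROUPING_ID` folds the per-expression `GROUPING` flags into a bit vector (most significant flag first):

```scala
object GroupingIdSketch {
  // GROUPING_ID(e1, ..., en) = GROUPING(e1) * 2^(n-1) + ... + GROUPING(en)
  def groupingId(groupingFlags: Seq[Int]): Int =
    groupingFlags.foldLeft(0)((acc, flag) => acc * 2 + flag)

  def main(args: Array[String]): Unit = {
    // for GROUP BY GROUPING SETS (f1, f2):
    // a row produced by the set that groups on f1 has GROUPING(f1) = 0, GROUPING(f2) = 1
    println(groupingId(Seq(0, 1))) // 1
    // a row produced by the set that groups on f2 has GROUPING(f1) = 1, GROUPING(f2) = 0
    println(groupingId(Seq(1, 0))) // 2
  }
}
```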
<table class="table table-bordered">
<thead>
<tr>
......
......@@ -377,9 +377,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
lazy val prefixed: PackratParser[Expression] =
prefixArray | prefixSum | prefixMin | prefixMax | prefixCount | prefixAvg |
prefixStart | prefixEnd |
prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs | prefixIf | prefixExtract |
prefixFloor | prefixCeil | prefixGet | prefixFlattening |
prefixStart | prefixEnd | prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs |
prefixIf | prefixExtract | prefixFloor | prefixCeil | prefixGet | prefixFlattening |
prefixFunctionCall | prefixFunctionCallOneArg // function call must always be at the end
// suffix/prefix composite
......
......@@ -46,12 +46,13 @@ class DataSetAggregate(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
rowRelDataType: RelDataType,
inputType: RelDataType,
grouping: Array[Int])
grouping: Array[Int],
inGroupingSet: Boolean)
extends SingleRel(cluster, traitSet, inputNode)
with FlinkAggregate
with DataSetRel {
override def deriveRowType() = rowRelDataType
override def deriveRowType(): RelDataType = rowRelDataType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataSetAggregate(
......@@ -61,7 +62,8 @@ class DataSetAggregate(
namedAggregates,
getRowType,
inputType,
grouping)
grouping,
inGroupingSet)
}
override def toString: String = {
......@@ -104,7 +106,8 @@ class DataSetAggregate(
namedAggregates,
inputType,
rowRelDataType,
grouping)
grouping,
inGroupingSet)
val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(
tableEnv,
......
......@@ -58,7 +58,7 @@ class DataStreamAggregate(
with FlinkAggregate
with DataStreamRel {
override def deriveRowType() = rowRelDataType
override def deriveRowType(): RelDataType = rowRelDataType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamAggregate(
......@@ -242,6 +242,7 @@ class DataStreamAggregate(
}
}
}
// if the expected type is not a Row, inject a mapper to convert to the expected type
expectedType match {
case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] =>
......
......@@ -23,23 +23,21 @@ import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.logical.LogicalAggregate
import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention}
import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention, DataSetUnion}
import scala.collection.JavaConversions._
class DataSetAggregateRule
extends ConverterRule(
classOf[LogicalAggregate],
Convention.NONE,
DataSetConvention.INSTANCE,
"DataSetAggregateRule")
{
classOf[LogicalAggregate],
Convention.NONE,
DataSetConvention.INSTANCE,
"DataSetAggregateRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
//for non grouped agg sets should attach null row to source data
//need apply DataSetAggregateWithNullValuesRule
// for non-grouped aggregates a null row must be attached to the source data,
// so DataSetAggregateWithNullValuesRule needs to be applied instead
if (agg.getGroupSet.isEmpty) {
return false
}
......@@ -50,13 +48,7 @@ class DataSetAggregateRule
throw TableException("DISTINCT aggregates are currently not supported.")
}
// check if we have grouping sets
val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet
if (groupSets || agg.indicator) {
throw TableException("GROUPING SETS are currently not supported.")
}
!distinctAggs && !groupSets && !agg.indicator
!distinctAggs
}
override def convert(rel: RelNode): RelNode = {
......@@ -64,16 +56,43 @@ class DataSetAggregateRule
val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
val convInput: RelNode = RelOptRule.convert(agg.getInput, DataSetConvention.INSTANCE)
new DataSetAggregate(
rel.getCluster,
traitSet,
convInput,
agg.getNamedAggCalls,
rel.getRowType,
agg.getInput.getRowType,
agg.getGroupSet.toArray)
if (agg.indicator) {
agg.groupSets.map(set =>
new DataSetAggregate(
rel.getCluster,
traitSet,
convInput,
agg.getNamedAggCalls,
rel.getRowType,
agg.getInput.getRowType,
set.toArray,
inGroupingSet = true
).asInstanceOf[RelNode]
).reduce(
(rel1, rel2) => {
new DataSetUnion(
rel.getCluster,
traitSet,
rel1,
rel2,
rel.getRowType
)
}
)
} else {
new DataSetAggregate(
rel.getCluster,
traitSet,
convInput,
agg.getNamedAggCalls,
rel.getRowType,
agg.getInput.getRowType,
agg.getGroupSet.toArray,
inGroupingSet = false
)
}
}
}
object DataSetAggregateRule {
val INSTANCE: RelOptRule = new DataSetAggregateRule
......
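Conceptually, the rule above expands a grouping-sets aggregate into one `DataSetAggregate` per grouping set and combines the results with `DataSetUnion`. A library-free sketch of that expansion (the case classes are illustrative only, not Flink types):

```scala
// Sketch of the map/reduce expansion performed in convert() above.
sealed trait Plan
case class Aggregate(groupKeys: Seq[String]) extends Plan
case class Union(left: Plan, right: Plan) extends Plan

object GroupingSetsExpansion {
  def expand(groupSets: Seq[Seq[String]]): Plan =
    groupSets
      .map(set => Aggregate(set): Plan) // one aggregate per grouping set
      .reduce((l, r) => Union(l, r))    // fold them into a left-deep union tree

  def main(args: Array[String]): Unit = {
    // GROUP BY GROUPING SETS (b, c)
    println(expand(Seq(Seq("b"), Seq("c"))))
    // GROUP BY CUBE (b, c) == GROUPING SETS ((b, c), (b), (c), ())
    println(expand(Seq(Seq("b", "c"), Seq("b"), Seq("c"), Seq())))
  }
}
```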
......@@ -29,22 +29,21 @@ import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention}
/**
* Rule for insert [[org.apache.flink.types.Row]] with null records into a [[DataSetAggregate]]
* Rule apply for non grouped aggregate query
* Rule for inserting [[org.apache.flink.types.Row]] with null records into a [[DataSetAggregate]].
* The rule applies only to non-grouped aggregate queries.
*/
class DataSetAggregateWithNullValuesRule
extends ConverterRule(
classOf[LogicalAggregate],
Convention.NONE,
DataSetConvention.INSTANCE,
"DataSetAggregateWithNullValuesRule")
{
"DataSetAggregateWithNullValuesRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
//for grouped agg sets shouldn't attach of null row
//need apply other rules. e.g. [[DataSetAggregateRule]]
// grouped aggregates should not attach a null row;
// other rules, e.g. DataSetAggregateRule, need to be applied instead
if (!agg.getGroupSet.isEmpty) {
return false
}
......@@ -55,12 +54,7 @@ class DataSetAggregateWithNullValuesRule
throw TableException("DISTINCT aggregates are currently not supported.")
}
// check if we have grouping sets
val groupSets = agg.getGroupSets.size() == 0 || agg.getGroupSets.get(0) != agg.getGroupSet
if (groupSets || agg.indicator) {
throw TableException("GROUPING SETS are currently not supported.")
}
!distinctAggs && !groupSets && !agg.indicator
!distinctAggs
}
override def convert(rel: RelNode): RelNode = {
......@@ -87,7 +81,8 @@ class DataSetAggregateWithNullValuesRule
agg.getNamedAggCalls,
rel.getRowType,
agg.getInput.getRowType,
agg.getGroupSet.toArray
agg.getGroupSet.toArray,
inGroupingSet = false
)
}
}
......
......@@ -33,8 +33,7 @@ class DataStreamAggregateRule
classOf[LogicalWindowAggregate],
Convention.NONE,
DataStreamConvention.INSTANCE,
"DataStreamAggregateRule")
{
"DataStreamAggregateRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val agg: LogicalWindowAggregate = call.rel(0).asInstanceOf[LogicalWindowAggregate]
......@@ -75,4 +74,3 @@ class DataStreamAggregateRule
object DataStreamAggregateRule {
val INSTANCE: RelOptRule = new DataStreamAggregateRule
}
......@@ -37,7 +37,7 @@ class AggregateMapFunction[IN, OUT](
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(aggFields)
Preconditions.checkArgument(aggregates.size == aggFields.size)
Preconditions.checkArgument(aggregates.length == aggFields.length)
val partialRowLength = groupingKeys.length +
aggregates.map(_.intermediateDataType.length).sum
output = new Row(partialRowLength)
......@@ -46,11 +46,11 @@ class AggregateMapFunction[IN, OUT](
override def map(value: IN): OUT = {
val input = value.asInstanceOf[Row]
for (i <- 0 until aggregates.length) {
for (i <- aggregates.indices) {
val fieldValue = input.getField(aggFields(i))
aggregates(i).prepare(fieldValue, output)
}
for (i <- 0 until groupingKeys.length) {
for (i <- groupingKeys.indices) {
output.setField(i, input.getField(groupingKeys(i)))
}
output.asInstanceOf[OUT]
......
......@@ -25,28 +25,31 @@ import org.apache.flink.types.Row
import scala.collection.JavaConversions._
/**
* It wraps the aggregate logic inside of
* [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
* [[org.apache.flink.api.java.operators.GroupCombineOperator]]
*
* @param aggregates The aggregate functions.
* @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
* and output Row.
* @param aggregateMapping The index mapping between aggregate function list and aggregated value
* index in output Row.
* @param aggregates The aggregate functions.
* @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
* and output Row.
* @param aggregateMapping The index mapping between aggregate function list and aggregated value
* index in output Row.
* @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
* Row and output Row.
*/
class AggregateReduceCombineFunction(
private val aggregates: Array[Aggregate[_ <: Any]],
private val groupKeysMapping: Array[(Int, Int)],
private val aggregateMapping: Array[(Int, Int)],
private val groupingSetsMapping: Array[(Int, Int)],
private val intermediateRowArity: Int,
private val finalRowArity: Int)
extends AggregateReduceGroupFunction(
aggregates,
groupKeysMapping,
aggregateMapping,
groupingSetsMapping,
intermediateRowArity,
finalRowArity)
with CombineFunction[Row, Row] {
......
......@@ -20,38 +20,45 @@ package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
import scala.collection.JavaConversions._
/**
* It wraps the aggregate logic inside of
* It wraps the aggregate logic inside of
* [[org.apache.flink.api.java.operators.GroupReduceOperator]].
*
* @param aggregates The aggregate functions.
* @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
* and output Row.
* @param aggregateMapping The index mapping between aggregate function list and aggregated value
* index in output Row.
* @param aggregates The aggregate functions.
* @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
* and output Row.
* @param aggregateMapping The index mapping between aggregate function list and aggregated value
* index in output Row.
* @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
* Row and output Row.
*/
class AggregateReduceGroupFunction(
private val aggregates: Array[Aggregate[_ <: Any]],
private val groupKeysMapping: Array[(Int, Int)],
private val aggregateMapping: Array[(Int, Int)],
private val groupingSetsMapping: Array[(Int, Int)],
private val intermediateRowArity: Int,
private val finalRowArity: Int)
extends RichGroupReduceFunction[Row, Row] {
protected var aggregateBuffer: Row = _
private var output: Row = _
private var intermediateGroupKeys: Option[Array[Int]] = None
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(groupKeysMapping)
aggregateBuffer = new Row(intermediateRowArity)
output = new Row(finalRowArity)
if (!groupingSetsMapping.isEmpty) {
intermediateGroupKeys = Some(groupKeysMapping.map(_._1))
}
}
/**
......@@ -87,6 +94,14 @@ class AggregateReduceGroupFunction(
output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
}
// Evaluate additional values of grouping sets
if (intermediateGroupKeys.isDefined) {
groupingSetsMapping.foreach {
case (inputIndex, outputIndex) =>
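// an indicator field is true iff its grouping key is not part of the current grouping set (i.e. it was rolled up)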
output.setField(outputIndex, !intermediateGroupKeys.get.contains(inputIndex))
}
}
out.collect(output)
}
}
......@@ -241,10 +241,12 @@ object AggregateUtil {
*
*/
private[flink] def createAggregateGroupReduceFunction(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
outputType: RelDataType,
groupings: Array[Int]): RichGroupReduceFunction[Row, Row] = {
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
outputType: RelDataType,
groupings: Array[Int],
inGroupingSet: Boolean)
: RichGroupReduceFunction[Row, Row] = {
val aggregates = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
......@@ -258,6 +260,12 @@ object AggregateUtil {
outputType,
groupings)
val groupingSetsMapping: Array[(Int, Int)] = if (inGroupingSet) {
getGroupingSetsIndicatorMapping(inputType, outputType)
} else {
Array()
}
val allPartialAggregate: Boolean = aggregates.forall(_.supportPartial)
val intermediateRowArity = groupings.length +
......@@ -269,6 +277,7 @@ object AggregateUtil {
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
groupingSetsMapping,
intermediateRowArity,
outputType.getFieldCount)
}
......@@ -277,6 +286,7 @@ object AggregateUtil {
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
groupingSetsMapping,
intermediateRowArity,
outputType.getFieldCount)
}
......@@ -329,7 +339,8 @@ object AggregateUtil {
namedAggregates,
inputType,
outputType,
groupings)
groupings,
inGroupingSet = false)
if (isTimeWindow(window)) {
val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
......@@ -358,7 +369,8 @@ object AggregateUtil {
namedAggregates,
inputType,
outputType,
groupings)
groupings,
inGroupingSet = false)
if (isTimeWindow(window)) {
val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
......@@ -371,7 +383,7 @@ object AggregateUtil {
/**
* Create an [[AllWindowFunction]] to finalize incrementally pre-computed non-partitioned
* window aggreagtes.
* window aggregates.
*/
private[flink] def createAllWindowIncrementalAggregationFunction(
window: LogicalWindow,
......@@ -495,6 +507,51 @@ object AggregateUtil {
(groupingOffsetMapping, aggOffsetMapping)
}
/**
* Determines the mapping of grouping keys to boolean indicators that describe the
* current grouping set.
*
* E.g.: Given we group on f1 and f2 of the input type, the output type contains two
* boolean indicator fields i$f1 and i$f2.
*/
private def getGroupingSetsIndicatorMapping(
inputType: RelDataType,
outputType: RelDataType)
: Array[(Int, Int)] = {
val inputFields = inputType.getFieldList.map(_.getName)
// map from field -> i$field or field -> i$field_0
val groupingFields = inputFields.map(inputFieldName => {
val base = "i$" + inputFieldName
var name = base
var i = 0
while (inputFields.contains(name)) {
name = base + "_" + i // if i$XXX is already a field it will be suffixed by _NUMBER
i = i + 1
}
inputFieldName -> name
}).toMap
val outputFields = outputType.getFieldList
var mappingsBuffer = ArrayBuffer[(Int, Int)]()
for (i <- outputFields.indices) {
for (j <- outputFields.indices) {
val possibleKey = outputFields(i).getName
val possibleIndicator1 = outputFields(j).getName
// get indicator for output field
val possibleIndicator2 = groupingFields.getOrElse(possibleKey, null)
// check if indicator names match
if (possibleIndicator1 == possibleIndicator2) {
mappingsBuffer += ((i, j))
}
}
}
mappingsBuffer.toArray
}
private def isTimeWindow(window: LogicalWindow) = {
window match {
case ProcessingTimeTumblingGroupWindow(_, size) => isTimeInterval(size.resultType)
......
......@@ -257,6 +257,10 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
SqlStdOperatorTable.NOT,
SqlStdOperatorTable.UNARY_MINUS,
SqlStdOperatorTable.UNARY_PLUS,
// GROUPING FUNCTIONS
SqlStdOperatorTable.GROUP_ID,
SqlStdOperatorTable.GROUPING,
SqlStdOperatorTable.GROUPING_ID,
// AGGREGATE OPERATORS
SqlStdOperatorTable.SUM,
SqlStdOperatorTable.COUNT,
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.api.java.batch.sql;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.java.BatchTableEnvironment;
import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase;
import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Comparator;
import java.util.List;
@RunWith(Parameterized.class)
public class GroupingSetsITCase extends TableProgramsTestBase {
private final static String TABLE_NAME = "MyTable";
private final static String TABLE_WITH_NULLS_NAME = "MyTableWithNulls";
private BatchTableEnvironment tableEnv;
public GroupingSetsITCase(TestExecutionMode mode, TableConfigMode tableConfigMode) {
super(mode, tableConfigMode);
}
@Before
public void setupTables() {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
tableEnv = TableEnvironment.getTableEnvironment(env, new TableConfig());
DataSet<Tuple3<Integer, Long, String>> dataSet = CollectionDataSets.get3TupleDataSet(env);
tableEnv.registerDataSet(TABLE_NAME, dataSet);
MapOperator<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> dataSetWithNulls =
dataSet.map(new MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>>() {
@Override
public Tuple3<Integer, Long, String> map(Tuple3<Integer, Long, String> value) throws Exception {
if (value.f2.toLowerCase().contains("world")) {
value.f2 = null;
}
return value;
}
});
tableEnv.registerDataSet(TABLE_WITH_NULLS_NAME, dataSetWithNulls);
}
@Test
public void testGroupingSets() throws Exception {
String query =
"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
" GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
" GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
" GROUPING_ID(f1, f2) as gid " +
" FROM " + TABLE_NAME +
" GROUP BY GROUPING SETS (f1, f2)";
String expected =
"1,null,1,1,0,1,0,1,1\n" +
"6,null,18,1,0,1,0,1,1\n" +
"2,null,2,1,0,1,0,1,1\n" +
"4,null,8,1,0,1,0,1,1\n" +
"5,null,13,1,0,1,0,1,1\n" +
"3,null,5,1,0,1,0,1,1\n" +
"null,Comment#11,17,2,1,0,1,0,2\n" +
"null,Comment#8,14,2,1,0,1,0,2\n" +
"null,Comment#2,8,2,1,0,1,0,2\n" +
"null,Comment#1,7,2,1,0,1,0,2\n" +
"null,Comment#14,20,2,1,0,1,0,2\n" +
"null,Comment#7,13,2,1,0,1,0,2\n" +
"null,Comment#6,12,2,1,0,1,0,2\n" +
"null,Comment#3,9,2,1,0,1,0,2\n" +
"null,Comment#12,18,2,1,0,1,0,2\n" +
"null,Comment#5,11,2,1,0,1,0,2\n" +
"null,Comment#15,21,2,1,0,1,0,2\n" +
"null,Comment#4,10,2,1,0,1,0,2\n" +
"null,Hi,1,2,1,0,1,0,2\n" +
"null,Comment#10,16,2,1,0,1,0,2\n" +
"null,Hello world,3,2,1,0,1,0,2\n" +
"null,I am fine.,5,2,1,0,1,0,2\n" +
"null,Hello world, how are you?,4,2,1,0,1,0,2\n" +
"null,Comment#9,15,2,1,0,1,0,2\n" +
"null,Comment#13,19,2,1,0,1,0,2\n" +
"null,Luke Skywalker,6,2,1,0,1,0,2\n" +
"null,Hello,2,2,1,0,1,0,2";
checkSql(query, expected);
}
@Test
public void testGroupingSetsWithNulls() throws Exception {
String query =
"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g FROM " + TABLE_WITH_NULLS_NAME +
" GROUP BY GROUPING SETS (f1, f2)";
String expected =
"6,null,18,1\n5,null,13,1\n4,null,8,1\n3,null,5,1\n2,null,2,1\n1,null,1,1\n" +
"null,Luke Skywalker,6,2\nnull,I am fine.,5,2\nnull,Hi,1,2\n" +
"null,null,3,2\nnull,Hello,2,2\nnull,Comment#9,15,2\nnull,Comment#8,14,2\n" +
"null,Comment#7,13,2\nnull,Comment#6,12,2\nnull,Comment#5,11,2\n" +
"null,Comment#4,10,2\nnull,Comment#3,9,2\nnull,Comment#2,8,2\n" +
"null,Comment#15,21,2\nnull,Comment#14,20,2\nnull,Comment#13,19,2\n" +
"null,Comment#12,18,2\nnull,Comment#11,17,2\nnull,Comment#10,16,2\n" +
"null,Comment#1,7,2";
checkSql(query, expected);
}
@Test
public void testCubeAsGroupingSets() throws Exception {
String cubeQuery =
"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
" GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
" GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
" GROUPING_ID(f1, f2) as gid " +
" FROM " + TABLE_NAME + " GROUP BY CUBE (f1, f2)";
String groupingSetsQuery =
"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
" GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
" GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
" GROUPING_ID(f1, f2) as gid " +
" FROM " + TABLE_NAME +
" GROUP BY GROUPING SETS ((f1, f2), (f1), (f2), ())";
compareSql(cubeQuery, groupingSetsQuery);
}
@Test
public void testRollupAsGroupingSets() throws Exception {
String rollupQuery =
"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
" GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
" GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
" GROUPING_ID(f1, f2) as gid " +
" FROM " + TABLE_NAME + " GROUP BY ROLLUP (f1, f2)";
String groupingSetsQuery =
"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
" GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
" GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
" GROUPING_ID(f1, f2) as gid " +
" FROM " + TABLE_NAME +
" GROUP BY GROUPING SETS ((f1, f2), (f1), ())";
compareSql(rollupQuery, groupingSetsQuery);
}
/**
* Execute SQL query and check results.
*
* @param query SQL query.
* @param expected Expected result.
*/
private void checkSql(String query, String expected) throws Exception {
Table resultTable = tableEnv.sql(query);
DataSet<Row> resultDataSet = tableEnv.toDataSet(resultTable, Row.class);
List<Row> results = resultDataSet.collect();
TestBaseUtils.compareResultAsText(results, expected);
}
private void compareSql(String query1, String query2) throws Exception {
// Function to map row to string
MapFunction<Row, String> mapFunction = new MapFunction<Row, String>() {
@Override
public String map(Row value) throws Exception {
return value == null ? "null" : value.toString();
}
};
// Execute first query and store results
Table resultTable1 = tableEnv.sql(query1);
DataSet<Row> resultDataSet1 = tableEnv.toDataSet(resultTable1, Row.class);
List<String> results1 = resultDataSet1.map(mapFunction).collect();
// Execute second query and store results
Table resultTable2 = tableEnv.sql(query2);
DataSet<Row> resultDataSet2 = tableEnv.toDataSet(resultTable2, Row.class);
List<String> results2 = resultDataSet2.map(mapFunction).collect();
// Compare results
TestBaseUtils.compareResultCollections(results1, results2, new Comparator<String>() {
@Override
public int compare(String o1, String o2) {
return o2 == null ? o1 == null ? 0 : 1 : o1.compareTo(o2);
}
});
}
}
......@@ -245,19 +245,31 @@ class AggregationsITCase(
tEnv.sql(sqlQuery).toDataSet[Row]
}
@Test(expected = classOf[TableException])
@Test
def testGroupingSetAggregate(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val sqlQuery = "SELECT _2, _3, avg(_1) as a FROM MyTable GROUP BY GROUPING SETS (_2, _3)"
val sqlQuery =
"SELECT _2, _3, avg(_1) as a, GROUP_ID() as g FROM MyTable GROUP BY GROUPING SETS (_2, _3)"
val ds = CollectionDataSets.get3TupleDataSet(env)
tEnv.registerDataSet("MyTable", ds)
// must fail. grouping sets are not supported
tEnv.sql(sqlQuery).toDataSet[Row]
val result = tEnv.sql(sqlQuery).toDataSet[Row].collect()
val expected =
"6,null,18,1\n5,null,13,1\n4,null,8,1\n3,null,5,1\n2,null,2,1\n1,null,1,1\n" +
"null,Luke Skywalker,6,2\nnull,I am fine.,5,2\nnull,Hi,1,2\n" +
"null,Hello world, how are you?,4,2\nnull,Hello world,3,2\nnull,Hello,2,2\n" +
"null,Comment#9,15,2\nnull,Comment#8,14,2\nnull,Comment#7,13,2\n" +
"null,Comment#6,12,2\nnull,Comment#5,11,2\nnull,Comment#4,10,2\n" +
"null,Comment#3,9,2\nnull,Comment#2,8,2\nnull,Comment#15,21,2\n" +
"null,Comment#14,20,2\nnull,Comment#13,19,2\nnull,Comment#12,18,2\n" +
"null,Comment#11,17,2\nnull,Comment#10,16,2\nnull,Comment#1,7,2"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.api.scala.batch.sql
import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
import org.apache.flink.table.utils.TableTestBase
import org.apache.flink.table.utils.TableTestUtil._
import org.junit.Test
class GroupingSetsTest extends TableTestBase {
@Test
def testGroupingSets(): Unit = {
val util = batchTestUtil()
util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
val sqlQuery = "SELECT b, c, avg(a) as a, GROUP_ID() as g FROM MyTable " +
"GROUP BY GROUPING SETS (b, c)"
val aggregate = unaryNode(
"DataSetCalc",
binaryNode(
"DataSetUnion",
unaryNode(
"DataSetAggregate",
batchTableNode(0),
term("groupBy", "b"),
term("select", "b", "AVG(a) AS c")
),
unaryNode(
"DataSetAggregate",
batchTableNode(0),
term("groupBy", "c"),
term("select", "c AS b", "AVG(a) AS c")
),
term("union", "b", "c", "i$b", "i$c", "a")
),
term("select",
"CASE(i$b, null, b) AS b",
"CASE(i$c, null, c) AS c",
"a",
"+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS g") // GROUP_ID()
)
util.verifySql(sqlQuery, aggregate)
}
@Test
def testCube(): Unit = {
val util = batchTestUtil()
util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
val sqlQuery = "SELECT b, c, avg(a) as a, GROUP_ID() as g, " +
"GROUPING(b) as gb, GROUPING(c) as gc, " +
"GROUPING_ID(b) as gib, GROUPING_ID(c) as gic, " +
"GROUPING_ID(b, c) as gid " +
"FROM MyTable " +
"GROUP BY CUBE (b, c)"
val group1 = unaryNode(
"DataSetAggregate",
batchTableNode(0),
term("groupBy", "b, c"),
term("select", "b", "c",
"AVG(a) AS i$b")
)
val group2 = unaryNode(
"DataSetAggregate",
batchTableNode(0),
term("groupBy", "b"),
term("select", "b",
"AVG(a) AS c")
)
val group3 = unaryNode(
"DataSetAggregate",
batchTableNode(0),
term("groupBy", "c"),
term("select", "c AS b",
"AVG(a) AS c")
)
val group4 = unaryNode(
"DataSetAggregate",
batchTableNode(0),
term("select",
"AVG(a) AS b")
)
val union1 = binaryNode(
"DataSetUnion",
group1, group2,
term("union", "b", "c", "i$b", "i$c", "a")
)
val union2 = binaryNode(
"DataSetUnion",
union1, group3,
term("union", "b", "c", "i$b", "i$c", "a")
)
val union3 = binaryNode(
"DataSetUnion",
union2, group4,
term("union", "b", "c", "i$b", "i$c", "a")
)
val aggregate = unaryNode(
"DataSetCalc",
union3,
term("select",
"CASE(i$b, null, b) AS b",
"CASE(i$c, null, c) AS c",
"a",
"+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS g", // GROUP_ID()
"CASE(i$b, 1, 0) AS gb", // GROUPING(b)
"CASE(i$c, 1, 0) AS gc", // GROUPING(c)
"CASE(i$b, 1, 0) AS gib", // GROUPING_ID(b)
"CASE(i$c, 1, 0) AS gic", // GROUPING_ID(c)
"+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS gid") // GROUPING_ID(b, c)
)
util.verifySql(sqlQuery, aggregate)
}
@Test
def testRollup(): Unit = {
val util = batchTestUtil()
util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
val sqlQuery = "SELECT b, c, avg(a) as a, GROUP_ID() as g, " +
"GROUPING(b) as gb, GROUPING(c) as gc, " +
"GROUPING_ID(b) as gib, GROUPING_ID(c) as gic, " +
"GROUPING_ID(b, c) as gid " + " FROM MyTable " +
"GROUP BY ROLLUP (b, c)"
val group1 = unaryNode(
"DataSetAggregate",
batchTableNode(0),
term("groupBy", "b, c"),
term("select", "b", "c",
"AVG(a) AS i$b")
)
val group2 = unaryNode(
"DataSetAggregate",
batchTableNode(0),
term("groupBy", "b"),
term("select", "b",
"AVG(a) AS c")
)
val group3 = unaryNode(
"DataSetAggregate",
batchTableNode(0),
term("select",
"AVG(a) AS b")
)
val union1 = binaryNode(
"DataSetUnion",
group1, group2,
term("union", "b", "c", "i$b", "i$c", "a")
)
val union2 = binaryNode(
"DataSetUnion",
union1, group3,
term("union", "b", "c", "i$b", "i$c", "a")
)
val aggregate = unaryNode(
"DataSetCalc",
union2,
term("select",
"CASE(i$b, null, b) AS b",
"CASE(i$c, null, c) AS c",
"a",
"+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS g", // GROUP_ID()
"CASE(i$b, 1, 0) AS gb", // GROUPING(b)
"CASE(i$c, 1, 0) AS gc", // GROUPING(c)
"CASE(i$b, 1, 0) AS gib", // GROUPING_ID(b)
"CASE(i$c, 1, 0) AS gic", // GROUPING_ID(c)
"+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS gid") // GROUPING_ID(b, c)
)
util.verifySql(sqlQuery, aggregate)
}
}