From dccdba199a8fbb8b5186f0952410c1b1b3dff14f Mon Sep 17 00:00:00 2001
From: Shuyi Chen
Date: Wed, 23 Aug 2017 17:54:10 -0700
Subject: [PATCH] [FLINK-7491] [table] Add MultiSet type and COLLECT
 aggregation function to SQL.

This closes #4585.
---
 docs/dev/table/sql.md                         |  13 +-
 .../flink/api/java/typeutils/MapTypeInfo.java |   2 +-
 .../api/java/typeutils/MultisetTypeInfo.java  |  91 +++++++
 .../java/typeutils/MultisetTypeInfoTest.java  |  38 +++
 .../org/apache/flink/table/api/Types.scala    |  11 +-
 .../table/calcite/FlinkTypeFactory.scala      |  24 +-
 .../table/codegen/ExpressionReducer.scala     |   9 +-
 .../aggfunctions/CollectAggFunction.scala     | 122 ++++++++++
 .../flink/table/plan/nodes/FlinkRelNode.scala |   2 +-
 .../plan/schema/MultisetRelDataType.scala     |  50 ++++
 .../runtime/aggregate/AggregateUtil.scala     |  10 +-
 .../table/validate/FunctionCatalog.scala      |   1 +
 .../aggfunctions/CollectAggFunctionTest.scala | 226 ++++++++++++++++++
 .../runtime/batch/sql/AggregateITCase.scala   |  29 +++
 .../table/runtime/stream/sql/SqlITCase.scala  |  59 +++++
 15 files changed, 677 insertions(+), 10 deletions(-)
 create mode 100644 flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java
 create mode 100644 flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala
 create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala

diff --git a/docs/dev/table/sql.md b/docs/dev/table/sql.md
index 533aa6e9ecf..81dabeeb9b7 100644
--- a/docs/dev/table/sql.md
+++ b/docs/dev/table/sql.md
@@ -803,6 +803,7 @@ The SQL runtime is built on top of Flink's DataSet and DataStream APIs. Internal
 | `Types.PRIMITIVE_ARRAY`| `ARRAY`                | e.g. `int[]`           |
 | `Types.OBJECT_ARRAY`   | `ARRAY`                | e.g. `java.lang.Byte[]`|
 | `Types.MAP`            | `MAP`                  | `java.util.HashMap`    |
+| `Types.MULTISET`       | `MULTISET`             | e.g. `java.util.HashMap<String, Integer>` for a multiset of `String` |
 
 Advanced types such as generic types, composite types (e.g. POJOs or Tuples), and array types (object or primitive arrays) can be fields of a row.
 
@@ -2164,6 +2165,17 @@ VAR_SAMP(value)
         <p>Returns the sample variance (square of the sample standard deviation) of the numeric field across all input values.</p>
       </td>
     </tr>
+    <tr>
+      <td>
+        {% highlight text %}
+COLLECT(value)
+        {% endhighlight %}
+      </td>
+      <td>
+        <p>Returns a multiset of the values. Null input values are ignored. Returns an empty multiset if only null values were added.</p>
+      </td>
+    </tr>
   </tbody>
 </table>
@@ -2283,7 +2295,6 @@ The following functions are not supported yet:
 
 - Binary string operators and functions
 - System functions
-- Collection functions
 - Distinct aggregate functions like COUNT DISTINCT
 
 {% top %}
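A usage sketch (not part of this patch) of what the documentation change above describes: COLLECT on a grouped stream, written against the Scala Table API and the pre-1.5 `tEnv.sql(...)` entry point that this patch itself uses. The table name `MyTable` and the `(a, b)` schema are made up for illustration.

```scala
import org.apache.flink.streaming.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.types.Row

object CollectExample {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // A small keyed input: fields (a, b); names are illustrative only.
    val t = env.fromElements((1, 1), (2, 2), (3, 2)).toTable(tEnv).as('a, 'b)
    tEnv.registerTable("MyTable", t)

    // COLLECT(a) aggregates each group's values into a multiset, which is
    // surfaced to user code as a java.util.Map from element to count.
    val result = tEnv.sql("SELECT b, COLLECT(a) FROM MyTable GROUP BY b")

    // A group-by on an unbounded stream emits updates, hence a retract stream.
    result.toRetractStream[Row].print()
    env.execute()
  }
}
```

Because the aggregation runs on an unbounded stream, each incoming row retracts the group's previous multiset and emits the updated one, as exercised by the `SqlITCase` tests at the end of this patch.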
diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java
index ca04e0cbc72..e9cd09dc217 100644
--- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java
+++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java
@@ -93,7 +93,7 @@ public class MapTypeInfo<K, V> extends TypeInformation<Map<K, V>> {
 
 	@Override
 	public int getTotalFields() {
-		return 2;
+		return 1;
 	}
 
 	@SuppressWarnings("unchecked")
diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java
new file mode 100644
index 00000000000..27fe70903ed
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.java.typeutils;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * A {@link TypeInformation} for the Multiset types of the Java API.
+ *
+ * @param <T> The type of the elements in the Multiset.
+ */
+@PublicEvolving
+public final class MultisetTypeInfo<T> extends MapTypeInfo<T, Integer> {
+
+	private static final long serialVersionUID = 1L;
+
+	public MultisetTypeInfo(Class<T> elementTypeClass) {
+		super(elementTypeClass, Integer.class);
+	}
+
+	public MultisetTypeInfo(TypeInformation<T> elementTypeInfo) {
+		super(elementTypeInfo, BasicTypeInfo.INT_TYPE_INFO);
+	}
+
+	// ------------------------------------------------------------------------
+	//  MultisetTypeInfo specific properties
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Gets the type information for the elements contained in the Multiset.
+	 */
+	public TypeInformation<T> getElementTypeInfo() {
+		return getKeyTypeInfo();
+	}
+
+	@Override
+	public String toString() {
+		return "Multiset<" + getKeyTypeInfo() + '>';
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (obj == this) {
+			return true;
+		}
+		else if (obj instanceof MultisetTypeInfo) {
+			final MultisetTypeInfo<?> other = (MultisetTypeInfo<?>) obj;
+			return other.canEqual(this) && getKeyTypeInfo().equals(other.getKeyTypeInfo());
+		} else {
+			return false;
+		}
+	}
+
+	@Override
+	public int hashCode() {
+		return 31 * getKeyTypeInfo().hashCode() + 1;
+	}
+
+	@Override
+	public boolean canEqual(Object obj) {
+		return obj != null && obj.getClass() == getClass();
+	}
+
+	@SuppressWarnings("unchecked")
+	@PublicEvolving
+	public static <C> MultisetTypeInfo<C> getInfoFor(TypeInformation<C> componentInfo) {
+		checkNotNull(componentInfo);
+
+		return new MultisetTypeInfo<>(componentInfo);
+	}
+}
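As a quick illustration of the new type (a sketch, not part of the patch): a multiset of `T` is physically a map from each element to its multiplicity, so `MultisetTypeInfo` simply fixes the value side of `MapTypeInfo` to `Integer`.

```scala
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.java.typeutils.MultisetTypeInfo

object MultisetTypeInfoExample extends App {
  // Element type String; the Integer count side is fixed by the constructor.
  val multisetType = MultisetTypeInfo.getInfoFor(BasicTypeInfo.STRING_TYPE_INFO)

  println(multisetType)                    // Multiset<String>
  println(multisetType.getElementTypeInfo) // String
  println(multisetType.getTotalFields)     // 1, per the MapTypeInfo fix above
}
```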
diff --git a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java
new file mode 100644
index 00000000000..395f4cef47a
--- /dev/null
+++ b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.java.typeutils;
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.TypeInformationTestBase;
+
+/**
+ * Test for {@link MultisetTypeInfo}.
+ */
+public class MultisetTypeInfoTest extends TypeInformationTestBase<MultisetTypeInfo<?>> {
+
+	@Override
+	protected MultisetTypeInfo<?>[] getTestData() {
+		return new MultisetTypeInfo<?>[] {
+			new MultisetTypeInfo<>(BasicTypeInfo.STRING_TYPE_INFO),
+			new MultisetTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+			new MultisetTypeInfo<>(Long.class)
+		};
+	}
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
index 2152b727fff..100c22b368e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala
@@ -18,7 +18,7 @@ package org.apache.flink.table.api
 
 import org.apache.flink.api.common.typeinfo.{PrimitiveArrayTypeInfo, TypeInformation, Types => JTypes}
-import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo}
+import org.apache.flink.api.java.typeutils.{MapTypeInfo, MultisetTypeInfo, ObjectArrayTypeInfo}
 import org.apache.flink.table.typeutils.TimeIntervalTypeInfo
 import org.apache.flink.types.Row
 
@@ -110,4 +110,13 @@ object Types {
   def MAP(keyType: TypeInformation[_], valueType: TypeInformation[_]): TypeInformation[_] = {
     new MapTypeInfo(keyType, valueType)
   }
+
+  /**
+    * Generates type information for a Multiset.
+    *
+    * @param elementType type of the elements of the multiset, e.g. Types.STRING
+    */
+  def MULTISET(elementType: TypeInformation[_]): TypeInformation[_] = {
+    new MultisetTypeInfo(elementType)
+  }
 }
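A short sketch of declaring a multiset type through the new Table API helper; it assumes the `STRING` constant that the `Types` object already exposes for the other SQL types.

```scala
import org.apache.flink.table.api.Types

object TypesMultisetExample extends App {
  // Type information for MULTISET<STRING>; at runtime the values are
  // java.util.Map[String, Integer] instances (element -> count).
  val multisetOfString = Types.MULTISET(Types.STRING)
  println(multisetOfString) // Multiset<String>
}
```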
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala
index 1cc9f6ba96f..768d70098ce 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala
@@ -31,7 +31,7 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
 import org.apache.flink.api.common.typeinfo._
 import org.apache.flink.api.common.typeutils.CompositeType
 import org.apache.flink.api.java.typeutils.ValueTypeInfo._
-import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
+import org.apache.flink.api.java.typeutils.{MapTypeInfo, MultisetTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
 import org.apache.flink.table.api.TableException
 import org.apache.flink.table.calcite.FlinkTypeFactory.typeInfoToSqlTypeName
 import org.apache.flink.table.plan.schema._
@@ -156,6 +163,13 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
         createTypeFromTypeInfo(mp.getValueTypeInfo, isNullable = true),
         isNullable)
 
+      case mts: MultisetTypeInfo[_] =>
+        new MultisetRelDataType(
+          mts,
+          createTypeFromTypeInfo(mts.getElementTypeInfo, isNullable = true),
+          isNullable
+        )
+
       case ti: TypeInformation[_] =>
         new GenericRelDataType(
           ti,
@@ -213,6 +220,14 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
     canonize(relType)
   }
 
+  override def createMultisetType(elementType: RelDataType, maxCardinality: Long): RelDataType = {
+    val relType = new MultisetRelDataType(
+      MultisetTypeInfo.getInfoFor(FlinkTypeFactory.toTypeInfo(elementType)),
+      elementType,
+      isNullable = false)
+    canonize(relType)
+  }
+
   override def createTypeWithNullability(
     relDataType: RelDataType,
     isNullable: Boolean): RelDataType = {
@@ -234,6 +249,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
       case map: MapRelDataType =>
         new MapRelDataType(map.typeInfo, map.keyType, map.valueType, isNullable)
 
+      case multiSet: MultisetRelDataType =>
+        new MultisetRelDataType(multiSet.typeInfo, multiSet.getComponentType, isNullable)
+
       case generic: GenericRelDataType =>
         new GenericRelDataType(generic.typeInfo, isNullable, typeSystem)
 
@@ -403,6 +421,10 @@ object FlinkTypeFactory {
       val mapRelDataType = relDataType.asInstanceOf[MapRelDataType]
       mapRelDataType.typeInfo
 
+    case MULTISET if relDataType.isInstanceOf[MultisetRelDataType] =>
+      val multisetRelDataType = relDataType.asInstanceOf[MultisetRelDataType]
+      multisetRelDataType.typeInfo
+
     case _@t =>
       throw TableException(s"Type is not supported: $t")
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
index 3e71c99ee4b..9696ced32c4 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/ExpressionReducer.scala
@@ -74,7 +74,8 @@ class ExpressionReducer(config: TableConfig)
       case (SqlTypeName.ANY, _) |
            (SqlTypeName.ROW, _) |
            (SqlTypeName.ARRAY, _) |
-           (SqlTypeName.MAP, _) => None
+           (SqlTypeName.MAP, _) |
+           (SqlTypeName.MULTISET, _) => None
 
       case (_, e) => Some(e)
     }
@@ -112,7 +113,11 @@ class ExpressionReducer(config: TableConfig)
         val unreduced = constExprs.get(i)
         unreduced.getType.getSqlTypeName match {
           // we insert the original expression for object literals
-          case SqlTypeName.ANY | SqlTypeName.ROW | SqlTypeName.ARRAY | SqlTypeName.MAP =>
+          case SqlTypeName.ANY |
+               SqlTypeName.ROW |
+               SqlTypeName.ARRAY |
+               SqlTypeName.MAP |
+               SqlTypeName.MULTISET =>
             reducedValues.add(unreduced)
           case _ =>
             val reducedValue = reduced.getField(reducedIdx)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
new file mode 100644
index 00000000000..b10be61a166
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.aggfunctions
+
+import java.lang.{Iterable => JIterable}
+import java.util
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils._
+import org.apache.flink.table.api.dataview.MapView
+import org.apache.flink.table.dataview.MapViewTypeInfo
+import org.apache.flink.table.functions.AggregateFunction
+
+import scala.collection.JavaConverters._
+
+/** The accumulator for the COLLECT aggregate function. */
+class CollectAccumulator[E](var map: MapView[E, Integer]) {
+  def this() {
+    this(null)
+  }
+
+  def canEqual(a: Any): Boolean = a.isInstanceOf[CollectAccumulator[E]]
+
+  override def equals(that: Any): Boolean =
+    that match {
+      case that: CollectAccumulator[E] => that.canEqual(this) && this.map == that.map
+      case _ => false
+    }
+}
+
+class CollectAggFunction[E](valueTypeInfo: TypeInformation[_])
+  extends AggregateFunction[util.Map[E, Integer], CollectAccumulator[E]] {
+
+  override def createAccumulator(): CollectAccumulator[E] = {
+    new CollectAccumulator[E](
+      new MapView[E, Integer](
+        valueTypeInfo.asInstanceOf[TypeInformation[E]],
+        BasicTypeInfo.INT_TYPE_INFO))
+  }
+
+  def accumulate(accumulator: CollectAccumulator[E], value: E): Unit = {
+    if (value != null) {
+      val currVal = accumulator.map.get(value)
+      if (currVal != null) {
+        accumulator.map.put(value, currVal + 1)
+      } else {
+        accumulator.map.put(value, 1)
+      }
+    }
+  }
+
+  override def getValue(accumulator: CollectAccumulator[E]): util.Map[E, Integer] = {
+    val iterator = accumulator.map.iterator
+    if (iterator.hasNext) {
+      val map = new util.HashMap[E, Integer]()
+      while (iterator.hasNext) {
+        val entry = iterator.next()
+        map.put(entry.getKey, entry.getValue)
+      }
+      map
+    } else {
+      Map[E, Integer]().asJava
+    }
+  }
+
+  def resetAccumulator(acc: CollectAccumulator[E]): Unit = {
+    acc.map.clear()
+  }
+
+  override def getAccumulatorType: TypeInformation[CollectAccumulator[E]] = {
+    val clazz = classOf[CollectAccumulator[E]]
+    val pojoFields = new util.ArrayList[PojoField]
+    pojoFields.add(new PojoField(clazz.getDeclaredField("map"),
+      new MapViewTypeInfo[E, Integer](
+        valueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO)))
+    new PojoTypeInfo[CollectAccumulator[E]](clazz, pojoFields)
+  }
+
+  def merge(acc: CollectAccumulator[E], its: JIterable[CollectAccumulator[E]]): Unit = {
+    val iter = its.iterator()
+    while (iter.hasNext) {
+      val mapViewIterator = iter.next().map.iterator
+      while (mapViewIterator.hasNext) {
+        val entry = mapViewIterator.next()
+        val k = entry.getKey
+        val oldValue = acc.map.get(k)
+        if (oldValue == null) {
+          acc.map.put(k, entry.getValue)
+        } else {
+          acc.map.put(k, entry.getValue + oldValue)
+        }
+      }
+    }
+  }
+
+  def retract(acc: CollectAccumulator[E], value: E): Unit = {
+    if (value != null) {
+      val count = acc.map.get(value)
+      if (count == 1) {
+        acc.map.remove(value)
+      } else {
+        acc.map.put(value, count - 1)
+      }
+    }
+  }
+}
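To make the accumulator lifecycle concrete, the sketch below (not part of the patch) drives `CollectAggFunction` by hand. It assumes that a `MapView` created via its public constructor is backed by an on-heap map when no state backend is wired in, which is how the `AggFunctionTestBase`-style tests later in this patch exercise it.

```scala
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.table.functions.aggfunctions.CollectAggFunction

object CollectAggFunctionExample extends App {
  val collect = new CollectAggFunction[String](BasicTypeInfo.STRING_TYPE_INFO)
  val acc = collect.createAccumulator()

  // Nulls are ignored; duplicates increment the element's count.
  collect.accumulate(acc, "a")
  collect.accumulate(acc, "a")
  collect.accumulate(acc, "b")
  collect.accumulate(acc, null)
  println(collect.getValue(acc)) // e.g. {a=2, b=1} (map order may vary)

  // Retraction decrements the count and drops the entry when it reaches zero.
  collect.retract(acc, "a")
  println(collect.getValue(acc)) // e.g. {a=1, b=1}
}
```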
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkRelNode.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkRelNode.scala
index 8509a8ec9de..f3e1a626482 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkRelNode.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkRelNode.scala
@@ -94,7 +94,7 @@ trait FlinkRelNode extends RelNode {
     case SqlTypeName.ARRAY =>
       // 16 is an arbitrary estimate
       estimateDataTypeSize(t.getComponentType) * 16
-    case SqlTypeName.MAP =>
+    case SqlTypeName.MAP | SqlTypeName.MULTISET =>
       // 16 is an arbitrary estimate
       (estimateDataTypeSize(t.getKeyType) + estimateDataTypeSize(t.getValueType)) * 16
     case SqlTypeName.ANY => 128 // 128 is an arbitrary estimate
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala
new file mode 100644
index 00000000000..859fc41daa0
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.schema
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.sql.`type`.MultisetSqlType
+import org.apache.flink.api.common.typeinfo.TypeInformation
+
+class MultisetRelDataType(
+    val typeInfo: TypeInformation[_],
+    elementType: RelDataType,
+    isNullable: Boolean)
+  extends MultisetSqlType(
+    elementType,
+    isNullable) {
+
+  override def toString = s"MULTISET($elementType)"
+
+  def canEqual(other: Any): Boolean = other.isInstanceOf[MultisetRelDataType]
+
+  override def equals(other: Any): Boolean = other match {
+    case that: MultisetRelDataType =>
+      super.equals(that) &&
+        (that canEqual this) &&
+        typeInfo == that.typeInfo &&
+        isNullable == that.isNullable
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    typeInfo.hashCode()
+  }
+
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 58940d06abb..c84b254d28b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -28,7 +28,7 @@ import org.apache.calcite.sql.{SqlAggFunction, SqlKind}
 import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction, AggregateFunction => DataStreamAggFunction, _}
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.java.tuple.Tuple
-import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo}
 import org.apache.flink.streaming.api.functions.ProcessFunction
 import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
 import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
@@ -1200,8 +1200,8 @@ object AggregateUtil {
       } else {
         aggFieldIndexes(index) = argList.asScala.map(i => i.intValue).toArray
       }
-      val sqlTypeName = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType
-        .getSqlTypeName
+      val relDataType = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType
+      val sqlTypeName = relDataType.getSqlTypeName
 
       aggregateCall.getAggregation match {
         case _: SqlSumAggFunction =>
@@ -1410,6 +1410,10 @@ object AggregateUtil {
         case _: SqlCountAggFunction =>
           aggregates(index) = new CountAggFunction
 
+        case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT =>
+          aggregates(index) = new CollectAggFunction(FlinkTypeFactory.toTypeInfo(relDataType))
+          accTypes(index) = aggregates(index).getAccumulatorType
+
         case udagg: AggSqlFunction =>
           aggregates(index) = udagg.getFunction
           accTypes(index) = udagg.accType
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index 5254ceb61e5..3398a930f61 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -319,6 +319,7 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
     SqlStdOperatorTable.SUM,
     SqlStdOperatorTable.SUM0,
     SqlStdOperatorTable.COUNT,
+    SqlStdOperatorTable.COLLECT,
     SqlStdOperatorTable.MIN,
     SqlStdOperatorTable.MAX,
     SqlStdOperatorTable.AVG,
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala
new file mode 100644
index 00000000000..f85cb70a562
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggfunctions
+
+import java.util
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.api.java.typeutils.GenericTypeInfo
+import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.aggfunctions._
+
+import scala.collection.JavaConverters._
+
+/**
+  * Test cases for the built-in COLLECT aggregate function.
+  */
+class StringCollectAggFunctionTest
+  extends AggFunctionTestBase[util.Map[String, Integer], CollectAccumulator[String]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq("a", "a", "b", null, "c", null, "d", "e", null, "f"),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[util.Map[String, Integer]] = {
+    val map = new util.HashMap[String, Integer]()
+    map.put("a", 2)
+    map.put("b", 1)
+    map.put("c", 1)
+    map.put("d", 1)
+    map.put("e", 1)
+    map.put("f", 1)
+    Seq(map, Map[String, Integer]().asJava)
+  }
+
+  override def aggregator: AggregateFunction[
+    util.Map[String, Integer], CollectAccumulator[String]] =
+    new CollectAggFunction(BasicTypeInfo.STRING_TYPE_INFO)
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class IntCollectAggFunctionTest
+  extends AggFunctionTestBase[util.Map[Int, Integer], CollectAccumulator[Int]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1, 1, 2, null, 3, null, 4, 5, null),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[util.Map[Int, Integer]] = {
+    val map = new util.HashMap[Int, Integer]()
+    map.put(1, 2)
+    map.put(2, 1)
+    map.put(3, 1)
+    map.put(4, 1)
+    map.put(5, 1)
+    Seq(map, Map[Int, Integer]().asJava)
+  }
+
+  override def aggregator: AggregateFunction[util.Map[Int, Integer], CollectAccumulator[Int]] =
+    new CollectAggFunction(BasicTypeInfo.INT_TYPE_INFO)
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class ByteCollectAggFunctionTest
+  extends AggFunctionTestBase[util.Map[Byte, Integer], CollectAccumulator[Byte]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1.toByte, 1.toByte, 2.toByte, null, 3.toByte, null, 4.toByte, 5.toByte, null),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[util.Map[Byte, Integer]] = {
+    val map = new util.HashMap[Byte, Integer]()
+    map.put(1, 2)
+    map.put(2, 1)
+    map.put(3, 1)
+    map.put(4, 1)
+    map.put(5, 1)
+    Seq(map, Map[Byte, Integer]().asJava)
+  }
+
+  override def aggregator: AggregateFunction[util.Map[Byte, Integer], CollectAccumulator[Byte]] =
+    new CollectAggFunction(BasicTypeInfo.BYTE_TYPE_INFO)
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class ShortCollectAggFunctionTest
+  extends AggFunctionTestBase[util.Map[Short, Integer], CollectAccumulator[Short]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1.toShort, 1.toShort, 2.toShort, null,
+      3.toShort, null, 4.toShort, 5.toShort, null),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[util.Map[Short, Integer]] = {
+    val map = new util.HashMap[Short, Integer]()
+    map.put(1, 2)
+    map.put(2, 1)
+    map.put(3, 1)
+    map.put(4, 1)
+    map.put(5, 1)
+    Seq(map, Map[Short, Integer]().asJava)
+  }
+
+  override def aggregator: AggregateFunction[util.Map[Short, Integer], CollectAccumulator[Short]] =
+    new CollectAggFunction(BasicTypeInfo.SHORT_TYPE_INFO)
+
+  override def retractFunc =
+    aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class LongCollectAggFunctionTest
+  extends AggFunctionTestBase[util.Map[Long, Integer], CollectAccumulator[Long]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1L, 1L, 2L, null, 3L, null, 4L, 5L, null),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[util.Map[Long, Integer]] = {
+    val map = new util.HashMap[Long, Integer]()
+    map.put(1, 2)
+    map.put(2, 1)
+    map.put(3, 1)
+    map.put(4, 1)
+    map.put(5, 1)
+    Seq(map, Map[Long, Integer]().asJava)
+  }
+
+  override def aggregator: AggregateFunction[util.Map[Long, Integer], CollectAccumulator[Long]] =
+    new CollectAggFunction(BasicTypeInfo.LONG_TYPE_INFO)
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class FloatCollectAggFunctionTest
+  extends AggFunctionTestBase[util.Map[Float, Integer], CollectAccumulator[Float]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1f, 1f, 2f, null, 3.2f, null, 4f, 5f, null),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[util.Map[Float, Integer]] = {
+    val map = new util.HashMap[Float, Integer]()
+    map.put(1, 2)
+    map.put(2, 1)
+    map.put(3.2f, 1)
+    map.put(4, 1)
+    map.put(5, 1)
+    Seq(map, Map[Float, Integer]().asJava)
+  }
+
+  override def aggregator: AggregateFunction[util.Map[Float, Integer], CollectAccumulator[Float]] =
+    new CollectAggFunction(BasicTypeInfo.FLOAT_TYPE_INFO)
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class DoubleCollectAggFunctionTest
+  extends AggFunctionTestBase[util.Map[Double, Integer], CollectAccumulator[Double]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1d, 1d, 2d, null, 3.2d, null, 4d, 5d),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[util.Map[Double, Integer]] = {
+    val map = new util.HashMap[Double, Integer]()
+    map.put(1, 2)
+    map.put(2, 1)
+    map.put(3.2d, 1)
+    map.put(4, 1)
+    map.put(5, 1)
+    Seq(map, Map[Double, Integer]().asJava)
+  }
+
+  override def aggregator: AggregateFunction[
+    util.Map[Double, Integer], CollectAccumulator[Double]] =
+    new CollectAggFunction(BasicTypeInfo.DOUBLE_TYPE_INFO)
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class ObjectCollectAggFunctionTest
+  extends AggFunctionTestBase[util.Map[Object, Integer], CollectAccumulator[Object]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(Tuple2(1, "a"), Tuple2(1, "a"), null, Tuple2(2, "b")),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[util.Map[Object, Integer]] = {
+    val map = new util.HashMap[Object, Integer]()
+    map.put(Tuple2(1, "a"), 2)
+    map.put(Tuple2(2, "b"), 1)
+    Seq(map, Map[Object, Integer]().asJava)
+  }
+
+  override def aggregator: AggregateFunction[
+    util.Map[Object, Integer], CollectAccumulator[Object]] =
+    new CollectAggFunction(new GenericTypeInfo[Object](classOf[Object]))
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala
index 465a88ccf92..aa934c685a8 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala
@@ -328,6 +328,35 @@ class AggregateITCase(
     TestBaseUtils.compareResultAsText(result.asJava, expected)
   }
 
+  @Test
+  def testTumbleWindowAggregateWithCollect(): Unit = {
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val sqlQuery =
+      "SELECT b, COLLECT(b) " +
+        "FROM T " +
+        "GROUP BY b, TUMBLE(ts, INTERVAL '3' SECOND)"
+
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+      // create timestamps
+      .map(x => (x._1, x._2, x._3, new Timestamp(x._1 * 1000)))
+    tEnv.registerDataSet("T", ds, 'a, 'b, 'c, 'ts)
+
+    val result = tEnv.sql(sqlQuery).toDataSet[Row].collect()
+    val expected = Seq(
+      "1,{1=1}",
+      "2,{2=1}", "2,{2=1}",
+      "3,{3=1}", "3,{3=2}",
+      "4,{4=2}", "4,{4=2}",
+      "5,{5=1}", "5,{5=1}", "5,{5=3}",
+      "6,{6=1}", "6,{6=2}", "6,{6=3}"
+    ).mkString("\n")
+
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
+  }
+
   @Test
   def testHopWindowAggregate(): Unit = {
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
index 2c82d9c2905..32e37243a9f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
@@ -92,6 +92,65 @@ class SqlITCase extends StreamingWithStateTestBase {
     assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
   }
 
+  @Test
+  def testUnboundedGroupByCollect(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    env.setStateBackend(getStateBackend)
+    StreamITCase.clear
+
+    val sqlQuery = "SELECT b, COLLECT(a) FROM MyTable GROUP BY b"
+
+    val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+    tEnv.registerTable("MyTable", t)
+
+    val result = tEnv.sql(sqlQuery).toRetractStream[Row]
+    result.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = List(
+      "1,{1=1}",
+      "2,{2=1, 3=1}",
+      "3,{4=1, 5=1, 6=1}",
+      "4,{7=1, 8=1, 9=1, 10=1}",
+      "5,{11=1, 12=1, 13=1, 14=1, 15=1}",
+      "6,{16=1, 17=1, 18=1, 19=1, 20=1, 21=1}")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
+  def testUnboundedGroupByCollectWithObject(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    env.setStateBackend(getStateBackend)
+    StreamITCase.clear
+
+    val sqlQuery = "SELECT b, COLLECT(c) FROM MyTable GROUP BY b"
+
+    val data = List(
+      (1, 1, (12, "45.6")),
+      (2, 2, (12, "45.612")),
+      (3, 2, (13, "41.6")),
+      (4, 3, (14, "45.2136")),
+      (5, 3, (18, "42.6"))
+    )
+
+    tEnv.registerTable("MyTable",
+      env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c))
+
+    val result = tEnv.sql(sqlQuery).toRetractStream[Row]
+    result.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = List(
+      "1,{(12,45.6)=1}",
+      "2,{(13,41.6)=1, (12,45.612)=1}",
+      "3,{(18,42.6)=1, (14,45.2136)=1}")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
   /** test selection **/
   @Test
   def testSelectExpressionFromTable(): Unit = {
-- 
GitLab