Commit dccdba19 authored by Shuyi Chen, committed by Fabian Hueske

[FLINK-7491] [table] Add MultiSet type and COLLECT aggregation function to SQL.

This closes #4585.
Parent 4047be49
......@@ -803,6 +803,7 @@ The SQL runtime is built on top of Flink's DataSet and DataStream APIs. Internal
| `Types.PRIMITIVE_ARRAY`| `ARRAY` | e.g. `int[]` |
| `Types.OBJECT_ARRAY` | `ARRAY` | e.g. `java.lang.Byte[]`|
| `Types.MAP` | `MAP` | `java.util.HashMap` |
| `Types.MULTISET` | `MULTISET` | e.g. `java.util.HashMap<String, Integer>` for a multiset of `String` |
Advanced types such as generic types, composite types (e.g. POJOs or Tuples), and array types (object or primitive arrays) can be fields of a row.
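For illustration, a minimal sketch of how a `MULTISET` value is represented on the Java side (the element values below are hypothetical, not part of this commit):

{% highlight scala %}
import java.util.{HashMap => JHashMap, Map => JMap}

// A MULTISET<String> is backed by a java.util.HashMap[String, Integer]
// that maps each distinct element to its multiplicity.
val multiset: JMap[String, Integer] = new JHashMap[String, Integer]()
multiset.put("hello", 2) // "hello" occurs twice
multiset.put("world", 1) // "world" occurs once
{% endhighlight %}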
......@@ -2164,6 +2165,17 @@ VAR_SAMP(value)
<p>Returns the sample variance (square of the sample standard deviation) of the numeric field across all input values.</p>
</td>
</tr>
<tr>
<td>
{% highlight text %}
COLLECT(value)
{% endhighlight %}
</td>
<td>
<p>Returns a multiset of the <i>value</i>s. Null input <i>value</i>s are ignored. Returns an empty multiset if only null values were added.</p>
</td>
</tr>
</tbody>
</table>
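For context, a hedged end-to-end sketch of <code>COLLECT</code> from the Scala Table API; the table name <code>MyTable</code>, the columns <code>a</code> and <code>b</code>, and the sample rows are illustrative, while the query shape follows the tests added in this commit:

{% highlight scala %}
import org.apache.flink.streaming.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.types.Row

val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)

val t = env.fromElements((1, "a"), (2, "a"), (2, "b")).toTable(tEnv).as('a, 'b)
tEnv.registerTable("MyTable", t)

// COLLECT(a) builds a multiset per group: for b = "a" the result is
// {1=1, 2=1}, surfaced as a java.util.Map from element to count.
val result = tEnv.sql("SELECT b, COLLECT(a) FROM MyTable GROUP BY b")
  .toRetractStream[Row]
{% endhighlight %}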
......@@ -2283,7 +2295,6 @@ The following functions are not supported yet:
- Binary string operators and functions
- System functions
- Collection functions
- Distinct aggregate functions like COUNT DISTINCT
{% top %}
......
......@@ -93,7 +93,7 @@ public class MapTypeInfo<K, V> extends TypeInformation<Map<K, V>> {
@Override
public int getTotalFields() {
return 1; // changed from 2: a map counts as one atomic field
}
@SuppressWarnings("unchecked")
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.typeutils;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import static org.apache.flink.util.Preconditions.checkNotNull;
/**
* A {@link TypeInformation} for the Multiset types of the Java API.
*
* @param <T> The type of the elements in the Multiset.
*/
@PublicEvolving
public final class MultisetTypeInfo<T> extends MapTypeInfo<T, Integer> {
private static final long serialVersionUID = 1L;
public MultisetTypeInfo(Class<T> elementTypeClass) {
super(elementTypeClass, Integer.class);
}
public MultisetTypeInfo(TypeInformation<T> elementTypeInfo) {
super(elementTypeInfo, BasicTypeInfo.INT_TYPE_INFO);
}
// ------------------------------------------------------------------------
// MultisetTypeInfo specific properties
// ------------------------------------------------------------------------
/**
* Gets the type information for the elements contained in the Multiset.
*/
public TypeInformation<T> getElementTypeInfo() {
return getKeyTypeInfo();
}
@Override
public String toString() {
return "Multiset<" + getKeyTypeInfo() + '>';
}
@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
else if (obj instanceof MultisetTypeInfo) {
final MultisetTypeInfo<?> other = (MultisetTypeInfo<?>) obj;
return other.canEqual(this) && getKeyTypeInfo().equals(other.getKeyTypeInfo());
} else {
return false;
}
}
@Override
public int hashCode() {
return 31 * getKeyTypeInfo().hashCode() + 1;
}
@Override
public boolean canEqual(Object obj) {
return obj != null && obj.getClass() == getClass();
}
@SuppressWarnings("unchecked")
@PublicEvolving
public static <C> MultisetTypeInfo<C> getInfoFor(TypeInformation<C> componentInfo) {
checkNotNull(componentInfo);
return new MultisetTypeInfo<>(componentInfo);
}
}
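A minimal construction sketch (in Scala, assuming the class above is on the classpath): the constructor and the static factory should yield equal type information for basic types.

// Both values wrap BasicTypeInfo.STRING_TYPE_INFO and print as Multiset<String>.
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.java.typeutils.MultisetTypeInfo

val byClass = new MultisetTypeInfo[String](classOf[String])
val byInfo  = MultisetTypeInfo.getInfoFor(BasicTypeInfo.STRING_TYPE_INFO)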
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.typeutils;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.TypeInformationTestBase;
/**
* Test for {@link MultisetTypeInfo}.
*/
public class MultisetTypeInfoTest extends TypeInformationTestBase<MultisetTypeInfo<?>> {
@Override
protected MultisetTypeInfo<?>[] getTestData() {
return new MultisetTypeInfo<?>[] {
new MultisetTypeInfo<>(BasicTypeInfo.STRING_TYPE_INFO),
new MultisetTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
new MultisetTypeInfo<>(Long.class)
};
}
}
......@@ -18,7 +18,7 @@
package org.apache.flink.table.api
import org.apache.flink.api.common.typeinfo.{PrimitiveArrayTypeInfo, TypeInformation, Types => JTypes}
import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo}
import org.apache.flink.api.java.typeutils.{MapTypeInfo, MultisetTypeInfo, ObjectArrayTypeInfo}
import org.apache.flink.table.typeutils.TimeIntervalTypeInfo
import org.apache.flink.types.Row
......@@ -110,4 +110,13 @@ object Types {
def MAP(keyType: TypeInformation[_], valueType: TypeInformation[_]): TypeInformation[_] = {
new MapTypeInfo(keyType, valueType)
}
/**
* Generates type information for a Multiset.
*
* @param elementType the type of the elements of the multiset, e.g. Types.STRING
*/
def MULTISET(elementType: TypeInformation[_]): TypeInformation[_] = {
new MultisetTypeInfo(elementType)
}
}
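A minimal usage sketch of the new factory method, assuming the `Types.STRING` constant from this same object:

// Type information for MULTISET<String>, equivalent to
// new MultisetTypeInfo(BasicTypeInfo.STRING_TYPE_INFO).
val multisetOfString = Types.MULTISET(Types.STRING)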
......@@ -31,7 +31,7 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo._
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.ValueTypeInfo._
import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
import org.apache.flink.api.java.typeutils.{MapTypeInfo, MultisetTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
import org.apache.flink.table.api.TableException
import org.apache.flink.table.calcite.FlinkTypeFactory.typeInfoToSqlTypeName
import org.apache.flink.table.plan.schema._
......@@ -156,6 +156,13 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
createTypeFromTypeInfo(mp.getValueTypeInfo, isNullable = true),
isNullable)
case mts: MultisetTypeInfo[_] =>
new MultisetRelDataType(
mts,
createTypeFromTypeInfo(mts.getElementTypeInfo, isNullable = true),
isNullable
)
case ti: TypeInformation[_] =>
new GenericRelDataType(
ti,
......@@ -213,6 +220,14 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
canonize(relType)
}
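// Note: maxCardinality is not used below; only the element type is
// captured in the resulting Flink type.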
override def createMultisetType(elementType: RelDataType, maxCardinality: Long): RelDataType = {
val relType = new MultisetRelDataType(
MultisetTypeInfo.getInfoFor(FlinkTypeFactory.toTypeInfo(elementType)),
elementType,
isNullable = false)
canonize(relType)
}
override def createTypeWithNullability(
relDataType: RelDataType,
isNullable: Boolean): RelDataType = {
......@@ -234,6 +249,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
case map: MapRelDataType =>
new MapRelDataType(map.typeInfo, map.keyType, map.valueType, isNullable)
case multiSet: MultisetRelDataType =>
new MultisetRelDataType(multiSet.typeInfo, multiSet.getComponentType, isNullable)
case generic: GenericRelDataType =>
new GenericRelDataType(generic.typeInfo, isNullable, typeSystem)
......@@ -403,6 +421,10 @@ object FlinkTypeFactory {
val mapRelDataType = relDataType.asInstanceOf[MapRelDataType]
mapRelDataType.typeInfo
case MULTISET if relDataType.isInstanceOf[MultisetRelDataType] =>
val multisetRelDataType = relDataType.asInstanceOf[MultisetRelDataType]
multisetRelDataType.typeInfo
case _@t =>
throw TableException(s"Type is not supported: $t")
}
......
......@@ -74,7 +74,8 @@ class ExpressionReducer(config: TableConfig)
case (SqlTypeName.ANY, _) |
(SqlTypeName.ROW, _) |
(SqlTypeName.ARRAY, _) |
(SqlTypeName.MAP, _) => None
(SqlTypeName.MAP, _) |
(SqlTypeName.MULTISET, _) => None
case (_, e) => Some(e)
}
......@@ -112,7 +113,11 @@ class ExpressionReducer(config: TableConfig)
val unreduced = constExprs.get(i)
unreduced.getType.getSqlTypeName match {
// we insert the original expression for object literals
case SqlTypeName.ANY | SqlTypeName.ROW | SqlTypeName.ARRAY | SqlTypeName.MAP =>
case SqlTypeName.ANY |
SqlTypeName.ROW |
SqlTypeName.ARRAY |
SqlTypeName.MAP |
SqlTypeName.MULTISET =>
reducedValues.add(unreduced)
case _ =>
val reducedValue = reduced.getField(reducedIdx)
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.functions.aggfunctions
import java.lang.{Iterable => JIterable}
import java.util
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils._
import org.apache.flink.table.api.dataview.MapView
import org.apache.flink.table.dataview.MapViewTypeInfo
import org.apache.flink.table.functions.AggregateFunction
import scala.collection.JavaConverters._
/** The initial accumulator for Collect aggregate function */
class CollectAccumulator[E](var map: MapView[E, Integer]) {
def this() {
this(null)
}
def canEqual(a: Any): Boolean = a.isInstanceOf[CollectAccumulator[E]]
override def equals(that: Any): Boolean =
that match {
case that: CollectAccumulator[E] => that.canEqual(this) && this.map == that.map
case _ => false
}
}
class CollectAggFunction[E](valueTypeInfo: TypeInformation[_])
extends AggregateFunction[util.Map[E, Integer], CollectAccumulator[E]] {
override def createAccumulator(): CollectAccumulator[E] = {
new CollectAccumulator[E](
new MapView[E, Integer](
valueTypeInfo.asInstanceOf[TypeInformation[E]],
BasicTypeInfo.INT_TYPE_INFO))
}
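/** Accumulates a value: null inputs are ignored; otherwise the element's
  * multiplicity in the backing MapView is incremented by one. */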
def accumulate(accumulator: CollectAccumulator[E], value: E): Unit = {
if (value != null) {
val currVal = accumulator.map.get(value)
if (currVal != null) {
accumulator.map.put(value, currVal + 1)
} else {
accumulator.map.put(value, 1)
}
}
}
override def getValue(accumulator: CollectAccumulator[E]): util.Map[E, Integer] = {
val iterator = accumulator.map.iterator
if (iterator.hasNext) {
val map = new util.HashMap[E, Integer]()
while (iterator.hasNext) {
val entry = iterator.next()
map.put(entry.getKey, entry.getValue)
}
map
} else {
Map[E, Integer]().asJava
}
}
def resetAccumulator(acc: CollectAccumulator[E]): Unit = {
acc.map.clear()
}
override def getAccumulatorType: TypeInformation[CollectAccumulator[E]] = {
val clazz = classOf[CollectAccumulator[E]]
val pojoFields = new util.ArrayList[PojoField]
pojoFields.add(new PojoField(clazz.getDeclaredField("map"),
new MapViewTypeInfo[E, Integer](
valueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO)))
new PojoTypeInfo[CollectAccumulator[E]](clazz, pojoFields)
}
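/** Merges the given accumulators into `acc` by summing per-element counts. */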
def merge(acc: CollectAccumulator[E], its: JIterable[CollectAccumulator[E]]): Unit = {
val iter = its.iterator()
while (iter.hasNext) {
val mapViewIterator = iter.next().map.iterator
while (mapViewIterator.hasNext) {
val entry = mapViewIterator.next()
val k = entry.getKey
val oldValue = acc.map.get(k)
if (oldValue == null) {
acc.map.put(k, entry.getValue)
} else {
acc.map.put(k, entry.getValue + oldValue)
}
}
}
}
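/** Retracts a value: decrements its count and removes the entry once the
  * count would reach zero. */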
def retract(acc: CollectAccumulator[E], value: E): Unit = {
if (value != null) {
val count = acc.map.get(value)
if (count == 1) {
acc.map.remove(value)
} else {
acc.map.put(value, count - 1)
}
}
}
}
......@@ -94,7 +94,7 @@ trait FlinkRelNode extends RelNode {
case SqlTypeName.ARRAY =>
// 16 is an arbitrary estimate
estimateDataTypeSize(t.getComponentType) * 16
case SqlTypeName.MAP =>
case SqlTypeName.MAP | SqlTypeName.MULTISET =>
// 16 is an arbitrary estimate
(estimateDataTypeSize(t.getKeyType) + estimateDataTypeSize(t.getValueType)) * 16
case SqlTypeName.ANY => 128 // 128 is an arbitrary estimate
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.schema
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.sql.`type`.MultisetSqlType
import org.apache.flink.api.common.typeinfo.TypeInformation
class MultisetRelDataType(
val typeInfo: TypeInformation[_],
elementType: RelDataType,
isNullable: Boolean)
extends MultisetSqlType(
elementType,
isNullable) {
override def toString = s"MULTISET($elementType)"
def canEqual(other: Any): Boolean = other.isInstanceOf[MultisetRelDataType]
override def equals(other: Any): Boolean = other match {
case that: MultisetRelDataType =>
super.equals(that) &&
(that canEqual this) &&
typeInfo == that.typeInfo &&
isNullable == that.isNullable
case _ => false
}
override def hashCode(): Int = {
typeInfo.hashCode()
}
}
......@@ -28,7 +28,7 @@ import org.apache.calcite.sql.{SqlAggFunction, SqlKind}
import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction, AggregateFunction => DataStreamAggFunction, _}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo}
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
......@@ -1200,8 +1200,8 @@ object AggregateUtil {
} else {
aggFieldIndexes(index) = argList.asScala.map(i => i.intValue).toArray
}
// was: inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType.getSqlTypeName
val relDataType = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType
val sqlTypeName = relDataType.getSqlTypeName
aggregateCall.getAggregation match {
case _: SqlSumAggFunction =>
......@@ -1410,6 +1410,10 @@ object AggregateUtil {
case _: SqlCountAggFunction =>
aggregates(index) = new CountAggFunction
case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT =>
aggregates(index) = new CollectAggFunction(FlinkTypeFactory.toTypeInfo(relDataType))
accTypes(index) = aggregates(index).getAccumulatorType
case udagg: AggSqlFunction =>
aggregates(index) = udagg.getFunction
accTypes(index) = udagg.accType
......
......@@ -319,6 +319,7 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
SqlStdOperatorTable.SUM,
SqlStdOperatorTable.SUM0,
SqlStdOperatorTable.COUNT,
SqlStdOperatorTable.COLLECT,
SqlStdOperatorTable.MIN,
SqlStdOperatorTable.MAX,
SqlStdOperatorTable.AVG,
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.aggfunctions
import java.util
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.java.typeutils.GenericTypeInfo
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.functions.aggfunctions._
import scala.collection.JavaConverters._
/**
* Test case for built-in collect aggregate functions
*/
class StringCollectAggFunctionTest
extends AggFunctionTestBase[util.Map[String, Integer], CollectAccumulator[String]] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq("a", "a", "b", null, "c", null, "d", "e", null, "f"),
Seq(null, null, null, null, null, null)
)
override def expectedResults: Seq[util.Map[String, Integer]] = {
val map = new util.HashMap[String, Integer]()
map.put("a", 2)
map.put("b", 1)
map.put("c", 1)
map.put("d", 1)
map.put("e", 1)
map.put("f", 1)
Seq(map, Map[String, Integer]().asJava)
}
override def aggregator: AggregateFunction[
util.Map[String, Integer], CollectAccumulator[String]] =
new CollectAggFunction(BasicTypeInfo.STRING_TYPE_INFO)
override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
class IntCollectAggFunctionTest
extends AggFunctionTestBase[util.Map[Int, Integer], CollectAccumulator[Int]] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(1, 1, 2, null, 3, null, 4, 5, null),
Seq(null, null, null, null, null, null)
)
override def expectedResults: Seq[util.Map[Int, Integer]] = {
val map = new util.HashMap[Int, Integer]()
map.put(1, 2)
map.put(2, 1)
map.put(3, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, Map[Int, Integer]().asJava)
}
override def aggregator: AggregateFunction[util.Map[Int, Integer], CollectAccumulator[Int]] =
new CollectAggFunction(BasicTypeInfo.INT_TYPE_INFO)
override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
class ByteCollectAggFunctionTest
extends AggFunctionTestBase[util.Map[Byte, Integer], CollectAccumulator[Byte]] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(1.toByte, 1.toByte, 2.toByte, null, 3.toByte, null, 4.toByte, 5.toByte, null),
Seq(null, null, null, null, null, null)
)
override def expectedResults: Seq[util.Map[Byte, Integer]] = {
val map = new util.HashMap[Byte, Integer]()
map.put(1, 2)
map.put(2, 1)
map.put(3, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, Map[Byte, Integer]().asJava)
}
override def aggregator: AggregateFunction[util.Map[Byte, Integer], CollectAccumulator[Byte]] =
new CollectAggFunction(BasicTypeInfo.BYTE_TYPE_INFO)
override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
class ShortCollectAggFunctionTest
extends AggFunctionTestBase[util.Map[Short, Integer], CollectAccumulator[Short]] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(1.toShort, 1.toShort, 2.toShort, null,
3.toShort, null, 4.toShort, 5.toShort, null),
Seq(null, null, null, null, null, null)
)
override def expectedResults: Seq[util.Map[Short, Integer]] = {
val map = new util.HashMap[Short, Integer]()
map.put(1, 2)
map.put(2, 1)
map.put(3, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, Map[Short, Integer]().asJava)
}
override def aggregator: AggregateFunction[util.Map[Short, Integer], CollectAccumulator[Short]] =
new CollectAggFunction(BasicTypeInfo.SHORT_TYPE_INFO)
override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
class LongCollectAggFunctionTest
extends AggFunctionTestBase[util.Map[Long, Integer], CollectAccumulator[Long]] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(1L, 1L, 2L, null, 3L, null, 4L, 5L, null),
Seq(null, null, null, null, null, null)
)
override def expectedResults: Seq[util.Map[Long, Integer]] = {
val map = new util.HashMap[Long, Integer]()
map.put(1, 2)
map.put(2, 1)
map.put(3, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, Map[Long, Integer]().asJava)
}
override def aggregator: AggregateFunction[util.Map[Long, Integer], CollectAccumulator[Long]] =
new CollectAggFunction(BasicTypeInfo.LONG_TYPE_INFO)
override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
class FloatCollectAggFunctionTest
extends AggFunctionTestBase[util.Map[Float, Integer], CollectAccumulator[Float]] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(1f, 1f, 2f, null, 3.2f, null, 4f, 5f, null),
Seq(null, null, null, null, null, null)
)
override def expectedResults: Seq[util.Map[Float, Integer]] = {
val map = new util.HashMap[Float, Integer]()
map.put(1, 2)
map.put(2, 1)
map.put(3.2f, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, Map[Float, Integer]().asJava)
}
override def aggregator: AggregateFunction[util.Map[Float, Integer], CollectAccumulator[Float]] =
new CollectAggFunction(BasicTypeInfo.FLOAT_TYPE_INFO)
override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
class DoubleCollectAggFunctionTest
extends AggFunctionTestBase[util.Map[Double, Integer], CollectAccumulator[Double]] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(1d, 1d, 2d, null, 3.2d, null, 4d, 5d),
Seq(null, null, null, null, null, null)
)
override def expectedResults: Seq[util.Map[Double, Integer]] = {
val map = new util.HashMap[Double, Integer]()
map.put(1, 2)
map.put(2, 1)
map.put(3.2d, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, Map[Double, Integer]().asJava)
}
override def aggregator: AggregateFunction[
util.Map[Double, Integer], CollectAccumulator[Double]] =
new CollectAggFunction(BasicTypeInfo.DOUBLE_TYPE_INFO)
override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
class ObjectCollectAggFunctionTest
extends AggFunctionTestBase[util.Map[Object, Integer], CollectAccumulator[Object]] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(Tuple2(1, "a"), Tuple2(1, "a"), null, Tuple2(2, "b")),
Seq(null, null, null, null, null, null)
)
override def expectedResults: Seq[util.Map[Object, Integer]] = {
val map = new util.HashMap[Object, Integer]()
map.put(Tuple2(1, "a"), 2)
map.put(Tuple2(2, "b"), 1)
Seq(map, Map[Object, Integer]().asJava)
}
override def aggregator: AggregateFunction[
util.Map[Object, Integer], CollectAccumulator[Object]] =
new CollectAggFunction(new GenericTypeInfo[Object](classOf[Object]))
override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
......@@ -328,6 +328,35 @@ class AggregateITCase(
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
def testTumbleWindowAggregateWithCollect(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val sqlQuery =
"SELECT b, COLLECT(b)" +
"FROM T " +
"GROUP BY b, TUMBLE(ts, INTERVAL '3' SECOND)"
val ds = CollectionDataSets.get3TupleDataSet(env)
// create timestamps
.map(x => (x._1, x._2, x._3, new Timestamp(x._1 * 1000)))
tEnv.registerDataSet("T", ds, 'a, 'b, 'c, 'ts)
val result = tEnv.sql(sqlQuery).toDataSet[Row].collect()
val expected = Seq(
"1,{1=1}",
"2,{2=1}", "2,{2=1}",
"3,{3=1}", "3,{3=2}",
"4,{4=2}", "4,{4=2}",
"5,{5=1}", "5,{5=1}", "5,{5=3}",
"6,{6=1}", "6,{6=2}", "6,{6=3}"
).mkString("\n")
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
def testHopWindowAggregate(): Unit = {
......
......@@ -92,6 +92,65 @@ class SqlITCase extends StreamingWithStateTestBase {
assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
}
@Test
def testUnboundedGroupByCollect(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
env.setStateBackend(getStateBackend)
StreamITCase.clear
val sqlQuery = "SELECT b, COLLECT(a) FROM MyTable GROUP BY b"
val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
tEnv.registerTable("MyTable", t)
val result = tEnv.sql(sqlQuery).toRetractStream[Row]
result.addSink(new StreamITCase.RetractingSink).setParallelism(1)
env.execute()
val expected = List(
"1,{1=1}",
"2,{2=1, 3=1}",
"3,{4=1, 5=1, 6=1}",
"4,{7=1, 8=1, 9=1, 10=1}",
"5,{11=1, 12=1, 13=1, 14=1, 15=1}",
"6,{16=1, 17=1, 18=1, 19=1, 20=1, 21=1}")
assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
}
@Test
def testUnboundedGroupByCollectWithObject(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
env.setStateBackend(getStateBackend)
StreamITCase.clear
val sqlQuery = "SELECT b, COLLECT(c) FROM MyTable GROUP BY b"
val data = List(
(1, 1, (12, "45.6")),
(2, 2, (12, "45.612")),
(3, 2, (13, "41.6")),
(4, 3, (14, "45.2136")),
(5, 3, (18, "42.6"))
)
tEnv.registerTable("MyTable",
env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c))
val result = tEnv.sql(sqlQuery).toRetractStream[Row]
result.addSink(new StreamITCase.RetractingSink).setParallelism(1)
env.execute()
val expected = List(
"1,{(12,45.6)=1}",
"2,{(13,41.6)=1, (12,45.612)=1}",
"3,{(18,42.6)=1, (14,45.2136)=1}")
assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
}
/** test selection **/
@Test
def testSelectExpressionFromTable(): Unit = {
......