提交 17dd915e 编写于 作者: Z Zhenghua Gao 提交者: twalthr

[FLINK-6124] [table] support max/min aggregations for string type

This closes #3579.
上级 86d32ac8
......@@ -155,3 +155,11 @@ class DecimalMaxAggFunction extends MaxAggFunction[BigDecimal] {
override def getInitValue = BigDecimal.ZERO
override def getValueTypeInfo = BasicTypeInfo.BIG_DEC_TYPE_INFO
}
/**
* Built-in String Max aggregate function
*/
class StringMaxAggFunction extends MaxAggFunction[String] {
override def getInitValue = "".toString
override def getValueTypeInfo = BasicTypeInfo.STRING_TYPE_INFO
}
......@@ -155,3 +155,11 @@ class DecimalMinAggFunction extends MinAggFunction[BigDecimal] {
override def getInitValue: BigDecimal = BigDecimal.ZERO
override def getValueTypeInfo = BasicTypeInfo.BIG_DEC_TYPE_INFO
}
/**
* Built-in String Min aggregate function
*/
class StringMinAggFunction extends MinAggFunction[String] {
override def getInitValue = "".toString
override def getValueTypeInfo = BasicTypeInfo.STRING_TYPE_INFO
}
......@@ -916,6 +916,8 @@ object AggregateUtil {
new DecimalMinAggFunction
case BOOLEAN =>
new BooleanMinAggFunction
case VARCHAR | CHAR =>
new StringMinAggFunction
case sqlType: SqlTypeName =>
throw new TableException("Min aggregate does no support type:" + sqlType)
}
......@@ -961,6 +963,8 @@ object AggregateUtil {
new DecimalMaxAggFunction
case BOOLEAN =>
new BooleanMaxAggFunction
case VARCHAR | CHAR =>
new StringMaxAggFunction
case sqlType: SqlTypeName =>
throw new TableException("Max aggregate does no support type:" + sqlType)
}
......
......@@ -92,36 +92,13 @@ class AggregationsITCase(
}
@Test
def testWorkingAggregationDataTypes(): Unit = {
def testAggregationDataTypes(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val sqlQuery =
"SELECT avg(_1), avg(_2), avg(_3), avg(_4), avg(_5), avg(_6), count(_7), " +
" sum(CAST(_6 AS DECIMAL))" +
"FROM MyTable"
val ds = env.fromElements(
(1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, "Hello"),
(2: Byte, 2: Short, 2, 2L, 2.0f, 2.0d, "Ciao"))
tEnv.registerDataSet("MyTable", ds)
val result = tEnv.sql(sqlQuery)
val expected = "1,1,1,1,1.5,1.5,2,3.0"
val results = result.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test
def testTableWorkingAggregationDataTypes(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val sqlQuery = "SELECT avg(a), avg(b), avg(c), avg(d), avg(e), avg(f), count(g)" +
"FROM MyTable"
val sqlQuery = "SELECT avg(a), avg(b), avg(c), avg(d), avg(e), avg(f), count(g), " +
"min(g), min('Ciao'), max(g), max('Ciao'), sum(CAST(f AS DECIMAL)) FROM MyTable"
val ds = env.fromElements(
(1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, "Hello"),
......@@ -130,7 +107,7 @@ class AggregationsITCase(
val result = tEnv.sql(sqlQuery)
val expected = "1,1,1,1,1.5,1.5,2"
val expected = "1,1,1,1,1.5,1.5,2,Ciao,Ciao,Hello,Ciao,3.0"
val results = result.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
......
......@@ -192,3 +192,36 @@ class DecimalMaxAggFunctionTest extends AggFunctionTestBase[BigDecimal] {
override def supportRetraction: Boolean = false
}
class StringMaxAggFunctionTest extends AggFunctionTestBase[String] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(
new String("a"),
new String("b"),
new String("c"),
null.asInstanceOf[String],
new String("d")
),
Seq(
null.asInstanceOf[String],
null.asInstanceOf[String],
null.asInstanceOf[String]
),
Seq(
new String("1House"),
new String("Household"),
new String("house"),
new String("household")
)
)
override def expectedResults: Seq[String] = Seq(
new String("d"),
null.asInstanceOf[String],
new String("household")
)
override def aggregator: AggregateFunction[String] = new StringMaxAggFunction()
override def supportRetraction: Boolean = false
}
......@@ -192,3 +192,36 @@ class DecimalMinAggFunctionTest extends AggFunctionTestBase[BigDecimal] {
override def supportRetraction: Boolean = false
}
class StringMinAggFunctionTest extends AggFunctionTestBase[String] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(
new String("a"),
new String("b"),
new String("c"),
null.asInstanceOf[String],
new String("d")
),
Seq(
null.asInstanceOf[String],
null.asInstanceOf[String],
null.asInstanceOf[String]
),
Seq(
new String("1House"),
new String("Household"),
new String("house"),
new String("household")
)
)
override def expectedResults: Seq[String] = Seq(
new String("a"),
null.asInstanceOf[String],
new String("1House")
)
override def aggregator: AggregateFunction[String] = new StringMinAggFunction()
override def supportRetraction: Boolean = false
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册