提交 20fe2af8 编写于 作者: C chengxiang li 提交者: Aljoscha Krettek

[FLINK-3087] [Table API] support multi count in aggregation.

上级 9215b724
......@@ -123,14 +123,17 @@ abstract class ExpressionCodeGenerator[R](
}
}
val cleanedExpr = expr match {
case expressions.Naming(namedExpr, _) => namedExpr
case _ => expr
def cleanedExpr(e: Expression): Expression = {
e match {
case expressions.Naming(namedExpr, _) => cleanedExpr(namedExpr)
case _ => e
}
}
val resultTpe = typeTermForTypeInfo(cleanedExpr.typeInfo)
val code: String = cleanedExpr match {
val cleanedExpression = cleanedExpr(expr)
val resultTpe = typeTermForTypeInfo(cleanedExpression.typeInfo)
val code: String = cleanedExpression match {
case expressions.Literal(null, typeInfo) =>
if (nullCheck) {
......
......@@ -51,20 +51,23 @@ object ExpandAggregations {
val aggregationIntermediates = mutable.HashMap[Aggregation, Seq[Expression]]()
var intermediateCount = 0
var resultCount = 0
selection foreach { f =>
f.transformPre {
case agg: Aggregation =>
val intermediateReferences = agg.getIntermediateFields.zip(agg.getAggregations) map {
case (expr, basicAgg) =>
resultCount += 1
val resultName = s"result.$resultCount"
aggregations.get((expr, basicAgg)) match {
case Some(intermediateName) =>
ResolvedFieldReference(intermediateName, expr.typeInfo)
Naming(ResolvedFieldReference(intermediateName, expr.typeInfo), resultName)
case None =>
intermediateCount = intermediateCount + 1
val intermediateName = s"intermediate.$intermediateCount"
intermediateFields += Naming(expr, intermediateName)
aggregations((expr, basicAgg)) = intermediateName
ResolvedFieldReference(intermediateName, expr.typeInfo)
Naming(ResolvedFieldReference(intermediateName, expr.typeInfo), resultName)
}
}
......
......@@ -137,6 +137,29 @@ public class AggregationsITCase extends MultipleProgramsTestBase {
compareResultAsText(results, expected);
}
@Test
public void testAggregationWithTwoCount() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
DataSource<Tuple2<Float, String>> input =
env.fromElements(
new Tuple2<>(1f, "Hello"),
new Tuple2<>(2f, "Ciao"));
Table table =
tableEnv.fromDataSet(input);
Table result =
table.select("f0.count, f1.count");
DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "2,2";
compareResultAsText(results, expected);
}
@Test(expected = ExpressionException.class)
public void testNonWorkingDataTypes() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
......
......@@ -80,6 +80,17 @@ class AggregationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBa
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test
def testAggregationWithTwoCount(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements((1f, "Hello"), (2f, "Ciao")).toTable
.select('_1.count, '_2.count).toDataSet[Row]
val expected = "2,2"
val results = ds.collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test(expected = classOf[ExpressionException])
def testNonWorkingAggregationDataTypes(): Unit = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册