From 37defbb428fc6fbf2f425bdfa03bafca8b2b6a5e Mon Sep 17 00:00:00 2001 From: twalthr Date: Mon, 9 May 2016 12:10:57 +0200 Subject: [PATCH] [FLINK-3859] [table] Add BigDecimal/BigInteger support to Table API This closes #2088. --- docs/apis/table.md | 4 +- .../flink/api/scala/table/expressionDsl.scala | 23 ++ .../flink/api/table/FlinkTypeSystem.scala | 36 ++ .../flink/api/table/TableEnvironment.scala | 1 + .../api/table/codegen/CodeGenUtils.scala | 48 ++- .../api/table/codegen/CodeGenerator.scala | 88 +++-- .../table/codegen/calls/BuiltInMethods.scala | 3 + .../codegen/calls/FloorCeilCallGen.scala | 5 +- .../table/codegen/calls/ScalarFunctions.scala | 15 + .../table/codegen/calls/ScalarOperators.scala | 314 ++++++++++++------ .../table/expressions/ExpressionParser.scala | 29 +- .../api/table/expressions/arithmetic.scala | 42 +-- .../api/table/expressions/comparison.scala | 34 +- .../api/table/expressions/literals.scala | 12 +- .../table/plan/nodes/dataset/DataSetRel.scala | 1 + .../runtime/aggregate/AggregateUtil.scala | 8 + .../runtime/aggregate/AvgAggregate.scala | 45 ++- .../runtime/aggregate/MaxAggregate.scala | 44 +++ .../runtime/aggregate/MinAggregate.scala | 43 +++ .../runtime/aggregate/SumAggregate.scala | 43 +++ .../api/table/typeutils/TypeCheckUtils.scala | 28 +- .../api/table/typeutils/TypeCoercion.scala | 16 +- .../api/table/typeutils/TypeConverter.scala | 2 + .../scala/batch/sql/AggregationsITCase.scala | 6 +- .../scala/batch/table/ExpressionsITCase.scala | 17 + .../table/expressions/DecimalTypeTest.scala | 311 +++++++++++++++++ .../expressions/ScalarFunctionsTest.scala | 182 +++++----- .../utils/ExpressionEvaluator.scala | 119 ------- .../utils/ExpressionTestBase.scala | 186 +++++++++++ .../runtime/aggregate/AggregateTestBase.scala | 10 +- .../runtime/aggregate/AvgAggregateTest.scala | 30 ++ .../runtime/aggregate/MaxAggregateTest.scala | 34 ++ .../runtime/aggregate/MinAggregateTest.scala | 34 ++ .../runtime/aggregate/SumAggregateTest.scala | 38 +++ 34 files changed, 1407 insertions(+), 444 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeSystem.scala create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/DecimalTypeTest.scala delete mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionEvaluator.scala create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala diff --git a/docs/apis/table.md b/docs/apis/table.md index f0a65286e78..2ccee9af5e9 100644 --- a/docs/apis/table.md +++ b/docs/apis/table.md @@ -752,7 +752,7 @@ suffixed = cast | as | aggregation | nullCheck | evaluate | functionCall ; cast = composite , ".cast(" , dataType , ")" ; -dataType = "BYTE" | "SHORT" | "INT" | "LONG" | "FLOAT" | "DOUBLE" | "BOOL" | "BOOLEAN" | "STRING" | "DATE" ; +dataType = "BYTE" | "SHORT" | "INT" | "LONG" | "FLOAT" | "DOUBLE" | "BOOL" | "BOOLEAN" | "STRING" | "DATE" | "DECIMAL"; as = composite , ".as(" , fieldReference , ")" ; @@ -773,6 +773,8 @@ nullLiteral = "Null(" , dataType , ")" ; Here, `literal` is a valid Java literal, `fieldReference` specifies a column in the data, and `functionIdentifier` specifies a supported scalar function. The column names and function names follow Java identifier syntax. Expressions specified as Strings can also use prefix notation instead of suffix notation to call operators and functions. 
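For example (an illustrative sketch; the `revenue` table and its `amount` column are assumed, not part of this patch), a string-based expression can cast to the new `DECIMAL` type and combine it with a precise decimal literal (the `p` suffix is explained below):

{% highlight scala %}
// "amount" is cast to DECIMAL before the precise literal 123456p is added
val result = revenue.select("amount.cast(DECIMAL) + 123456p")
{% endhighlight %}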
+If exact numeric values or large decimals are required, the Table API also supports Java's BigDecimal type. In the Scala Table API, decimals can be defined by `BigDecimal("123456")`; in Java, by appending a "p" for "precise", e.g. `123456p`.
+
 {% top %}
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
index 11fb64a40b4..93952abfa6b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
@@ -220,6 +220,14 @@ trait ImplicitExpressionConversions {
     def expr = Literal(l)
   }
 
+  implicit class LiteralByteExpression(b: Byte) extends ImplicitExpressionOperations {
+    def expr = Literal(b)
+  }
+
+  implicit class LiteralShortExpression(s: Short) extends ImplicitExpressionOperations {
+    def expr = Literal(s)
+  }
+
   implicit class LiteralIntExpression(i: Int) extends ImplicitExpressionOperations {
     def expr = Literal(i)
   }
@@ -240,11 +248,26 @@ trait ImplicitExpressionConversions {
     def expr = Literal(bool)
   }
 
+  implicit class LiteralJavaDecimalExpression(javaDecimal: java.math.BigDecimal)
+    extends ImplicitExpressionOperations {
+    def expr = Literal(javaDecimal)
+  }
+
+  implicit class LiteralScalaDecimalExpression(scalaDecimal: scala.math.BigDecimal)
+    extends ImplicitExpressionOperations {
+    def expr = Literal(scalaDecimal.bigDecimal)
+  }
+
   implicit def symbol2FieldExpression(sym: Symbol): Expression = UnresolvedFieldReference(sym.name)
+  implicit def byte2Literal(b: Byte): Expression = Literal(b)
+  implicit def short2Literal(s: Short): Expression = Literal(s)
   implicit def int2Literal(i: Int): Expression = Literal(i)
   implicit def long2Literal(l: Long): Expression = Literal(l)
   implicit def double2Literal(d: Double): Expression = Literal(d)
   implicit def float2Literal(d: Float): Expression = Literal(d)
   implicit def string2Literal(str: String): Expression = Literal(str)
   implicit def boolean2Literal(bool: Boolean): Expression = Literal(bool)
+  implicit def javaDec2Literal(javaDec: java.math.BigDecimal): Expression = Literal(javaDec)
+  implicit def scalaDec2Literal(scalaDec: scala.math.BigDecimal): Expression =
+    Literal(scalaDec.bigDecimal)
 }
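// Editor's sketch (table and field names assumed, not part of the patch): with the
// implicit conversions above in scope, both Java and Scala big decimals can be used
// directly as expressions in the Scala Table API:
//
//   import org.apache.flink.api.scala.table._
//
//   val threshold = scala.math.BigDecimal("123456.789")   // via scalaDec2Literal
//   val result = table
//     .filter('amount > threshold)
//     .select('amount + java.math.BigDecimal.TEN)         // via javaDec2Literal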
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeSystem.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeSystem.scala
new file mode 100644
index 00000000000..df6022a0c10
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeSystem.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table
+
+import org.apache.calcite.rel.`type`.RelDataTypeSystemImpl
+
+/**
+ * Custom type system for Flink.
+ */
+class FlinkTypeSystem extends RelDataTypeSystemImpl {
+
+  // we cannot use Int.MaxValue because of an overflow in Calcite's type inference logic;
+  // half should be enough for all use cases
+  override def getMaxNumericScale: Int = Int.MaxValue / 2
+
+  // we cannot use Int.MaxValue because of an overflow in Calcite's type inference logic;
+  // half should be enough for all use cases
+  override def getMaxNumericPrecision: Int = Int.MaxValue / 2
+
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
index 4d1bb1d3a6b..0f6cb242a8f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
@@ -69,6 +69,7 @@ abstract class TableEnvironment(val config: TableConfig) {
     .defaultSchema(tables)
     .parserConfig(parserConfig)
     .costFactory(new DataSetCostFactory)
+    .typeSystem(new FlinkTypeSystem)
     .build
 
   // the builder for Calcite RelNodes, Calcite's representation of a relational expression tree.
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala
index a24d74fb787..79e51e53a16 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala
@@ -23,11 +23,11 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
 import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo._
-import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.{FractionalTypeInfo, TypeInformation}
 import org.apache.flink.api.common.typeutils.CompositeType
-import org.apache.flink.api.java.typeutils.{TypeExtractor, PojoTypeInfo, TupleTypeInfo}
+import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo, TypeExtractor}
 import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
-import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeCheckUtils}
 
 object CodeGenUtils {
 
@@ -97,11 +97,24 @@ object CodeGenUtils {
     case _ => "null"
   }
 
-  def requireNumeric(genExpr: GeneratedExpression) = genExpr.resultType match {
-    case nti: NumericTypeInfo[_] => // ok
-    case _ => throw new CodeGenException("Numeric expression type expected.")
+  def superPrimitive(typeInfo: TypeInformation[_]): String = typeInfo match {
+    case _: FractionalTypeInfo[_] => "double"
+    case _ => "long"
   }
 
+  // ----------------------------------------------------------------------------------------------
+
+  def requireNumeric(genExpr: GeneratedExpression) =
+    if (!TypeCheckUtils.isNumeric(genExpr.resultType)) {
+      throw new CodeGenException("Numeric expression type expected, but was " +
+        s"'${genExpr.resultType}'.")
+    }
+
+  def requireComparable(genExpr: GeneratedExpression) =
+    if (!TypeCheckUtils.isComparable(genExpr.resultType)) {
throw new CodeGenException(s"Comparable type expected, but was '${genExpr.resultType}'.") + } + def requireString(genExpr: GeneratedExpression) = genExpr.resultType match { case STRING_TYPE_INFO => // ok case _ => throw new CodeGenException("String expression type expected.") @@ -112,6 +125,8 @@ object CodeGenUtils { case _ => throw new CodeGenException("Boolean expression type expected.") } + // ---------------------------------------------------------------------------------------------- + def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType) def isReference(typeInfo: TypeInformation[_]): Boolean = typeInfo match { @@ -126,27 +141,6 @@ object CodeGenUtils { case _ => true } - def isNumeric(genExpr: GeneratedExpression): Boolean = isNumeric(genExpr.resultType) - - def isNumeric(typeInfo: TypeInformation[_]): Boolean = typeInfo match { - case nti: NumericTypeInfo[_] => true - case _ => false - } - - def isString(genExpr: GeneratedExpression): Boolean = isString(genExpr.resultType) - - def isString(typeInfo: TypeInformation[_]): Boolean = typeInfo match { - case STRING_TYPE_INFO => true - case _ => false - } - - def isBoolean(genExpr: GeneratedExpression): Boolean = isBoolean(genExpr.resultType) - - def isBoolean(typeInfo: TypeInformation[_]): Boolean = typeInfo match { - case BOOLEAN_TYPE_INFO => true - case _ => false - } - // ---------------------------------------------------------------------------------------------- sealed abstract class FieldAccessor diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala index c8d6dca5ec5..5e0ac58e609 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala @@ -18,10 +18,12 @@ package org.apache.flink.api.table.codegen +import java.math.{BigDecimal => JBigDecimal} + import org.apache.calcite.rex._ -import org.apache.calcite.sql.{SqlLiteral, SqlOperator} import org.apache.calcite.sql.`type`.SqlTypeName._ import org.apache.calcite.sql.fun.SqlStdOperatorTable._ +import org.apache.calcite.sql.{SqlLiteral, SqlOperator} import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, Function, MapFunction} import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation} import org.apache.flink.api.common.typeutils.CompositeType @@ -32,8 +34,9 @@ import org.apache.flink.api.table.codegen.CodeGenUtils._ import org.apache.flink.api.table.codegen.Indenter.toISC import org.apache.flink.api.table.codegen.calls.ScalarFunctions import org.apache.flink.api.table.codegen.calls.ScalarOperators._ -import org.apache.flink.api.table.typeutils.{TypeConverter, RowTypeInfo} -import TypeConverter.sqlTypeToTypeInfo +import org.apache.flink.api.table.typeutils.RowTypeInfo +import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isNumeric, isString} +import org.apache.flink.api.table.typeutils.TypeConverter.sqlTypeToTypeInfo import scala.collection.JavaConversions._ import scala.collection.mutable @@ -542,7 +545,7 @@ class CodeGenerator( case BOOLEAN => generateNonNullLiteral(resultType, literal.getValue3.toString) case TINYINT => - val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) + val decimal = BigDecimal(value.asInstanceOf[JBigDecimal]) if (decimal.isValidByte) { 
generateNonNullLiteral(resultType, decimal.byteValue().toString) } @@ -550,7 +553,7 @@ class CodeGenerator( throw new CodeGenException("Decimal can not be converted to byte.") } case SMALLINT => - val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) + val decimal = BigDecimal(value.asInstanceOf[JBigDecimal]) if (decimal.isValidShort) { generateNonNullLiteral(resultType, decimal.shortValue().toString) } @@ -558,7 +561,7 @@ class CodeGenerator( throw new CodeGenException("Decimal can not be converted to short.") } case INTEGER => - val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) + val decimal = BigDecimal(value.asInstanceOf[JBigDecimal]) if (decimal.isValidInt) { generateNonNullLiteral(resultType, decimal.intValue().toString) } @@ -566,29 +569,36 @@ class CodeGenerator( throw new CodeGenException("Decimal can not be converted to integer.") } case BIGINT => - val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) + val decimal = BigDecimal(value.asInstanceOf[JBigDecimal]) if (decimal.isValidLong) { - generateNonNullLiteral(resultType, decimal.longValue().toString) + generateNonNullLiteral(resultType, decimal.longValue().toString + "L") } else { throw new CodeGenException("Decimal can not be converted to long.") } case FLOAT => - val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) - if (decimal.isValidFloat) { - generateNonNullLiteral(resultType, decimal.floatValue().toString + "f") - } - else { - throw new CodeGenException("Decimal can not be converted to float.") + val floatValue = value.asInstanceOf[JBigDecimal].floatValue() + floatValue match { + case Float.NaN => generateNonNullLiteral(resultType, "java.lang.Float.NaN") + case Float.NegativeInfinity => + generateNonNullLiteral(resultType, "java.lang.Float.NEGATIVE_INFINITY") + case Float.PositiveInfinity => + generateNonNullLiteral(resultType, "java.lang.Float.POSITIVE_INFINITY") + case _ => generateNonNullLiteral(resultType, floatValue.toString + "f") } case DOUBLE => - val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) - if (decimal.isValidDouble) { - generateNonNullLiteral(resultType, decimal.doubleValue().toString) - } - else { - throw new CodeGenException("Decimal can not be converted to double.") + val doubleValue = value.asInstanceOf[JBigDecimal].doubleValue() + doubleValue match { + case Double.NaN => generateNonNullLiteral(resultType, "java.lang.Double.NaN") + case Double.NegativeInfinity => + generateNonNullLiteral(resultType, "java.lang.Double.NEGATIVE_INFINITY") + case Double.PositiveInfinity => + generateNonNullLiteral(resultType, "java.lang.Double.POSITIVE_INFINITY") + case _ => generateNonNullLiteral(resultType, doubleValue.toString + "d") } + case DECIMAL => + val decimalField = addReusableDecimal(value.asInstanceOf[JBigDecimal]) + generateNonNullLiteral(resultType, decimalField) case VARCHAR | CHAR => generateNonNullLiteral(resultType, "\"" + value.toString + "\"") case SYMBOL => @@ -630,7 +640,7 @@ class CodeGenerator( val left = operands.head val right = operands(1) requireString(left) - generateArithmeticOperator("+", nullCheck, resultType, left, right) + generateStringConcatOperator(nullCheck, left, right) case MINUS if isNumeric(resultType) => val left = operands.head @@ -674,37 +684,39 @@ class CodeGenerator( case EQUALS => val left = operands.head val right = operands(1) - checkNumericOrString(left, right) generateEquals(nullCheck, left, right) case NOT_EQUALS => val left = operands.head val right = operands(1) - checkNumericOrString(left, 
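// Editor's note on the FLOAT/DOUBLE literal cases above: very large DECIMAL
// literals overflow to infinity when narrowed, and Java source has no literal
// syntax for NaN or the infinities, so the constants must be referenced
// explicitly. Sketch of the underlying behavior:
//
//   new java.math.BigDecimal("1E100").floatValue()   // Float.POSITIVE_INFINITY
//   // emitted as "java.lang.Float.POSITIVE_INFINITY" rather than "Infinityf"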
right) generateNotEquals(nullCheck, left, right) case GREATER_THAN => val left = operands.head val right = operands(1) - checkNumericOrString(left, right) + requireComparable(left) + requireComparable(right) generateComparison(">", nullCheck, left, right) case GREATER_THAN_OR_EQUAL => val left = operands.head val right = operands(1) - checkNumericOrString(left, right) + requireComparable(left) + requireComparable(right) generateComparison(">=", nullCheck, left, right) case LESS_THAN => val left = operands.head val right = operands(1) - checkNumericOrString(left, right) + requireComparable(left) + requireComparable(right) generateComparison("<", nullCheck, left, right) case LESS_THAN_OR_EQUAL => val left = operands.head val right = operands(1) - checkNumericOrString(left, right) + requireComparable(left) + requireComparable(right) generateComparison("<=", nullCheck, left, right) case IS_NULL => @@ -775,14 +787,6 @@ class CodeGenerator( // generator helping methods // ---------------------------------------------------------------------------------------------- - def checkNumericOrString(left: GeneratedExpression, right: GeneratedExpression): Unit = { - if (isNumeric(left)) { - requireNumeric(right) - } else if (isString(left)) { - requireString(right) - } - } - private def generateInputAccess( inputType: TypeInformation[Any], inputTerm: String, @@ -1036,4 +1040,18 @@ class CodeGenerator( fieldTerm } + def addReusableDecimal(decimal: JBigDecimal): String = decimal match { + case JBigDecimal.ZERO => "java.math.BigDecimal.ZERO" + case JBigDecimal.ONE => "java.math.BigDecimal.ONE" + case JBigDecimal.TEN => "java.math.BigDecimal.TEN" + case _ => + val fieldTerm = newName("decimal") + val fieldDecimal = + s""" + |transient java.math.BigDecimal $fieldTerm = + | new java.math.BigDecimal("${decimal.toString}"); + |""".stripMargin + reusableMemberStatements.add(fieldDecimal) + fieldTerm + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/BuiltInMethods.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/BuiltInMethods.scala index 080e1ba32fe..c3fbed39dbc 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/BuiltInMethods.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/BuiltInMethods.scala @@ -17,6 +17,8 @@ */ package org.apache.flink.api.table.codegen.calls +import java.math.{BigDecimal => JBigDecimal} + import org.apache.calcite.linq4j.tree.Types import org.apache.calcite.runtime.SqlFunctions @@ -26,4 +28,5 @@ object BuiltInMethods { val POWER = Types.lookupMethod(classOf[Math], "pow", classOf[Double], classOf[Double]) val LN = Types.lookupMethod(classOf[Math], "log", classOf[Double]) val ABS = Types.lookupMethod(classOf[SqlFunctions], "abs", classOf[Double]) + val ABS_DEC = Types.lookupMethod(classOf[SqlFunctions], "abs", classOf[JBigDecimal]) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FloorCeilCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FloorCeilCallGen.scala index 4bac718b78d..84f60a0d47d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FloorCeilCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FloorCeilCallGen.scala @@ -20,7 +20,8 @@ package org.apache.flink.api.table.codegen.calls import 
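// Editor's note on addReusableDecimal above (sketch of the generated Java): a
// literal such as 11.2 is parsed once and cached as a member of the generated
// function rather than being re-parsed per record, e.g. roughly
//
//   transient java.math.BigDecimal decimal$0 = new java.math.BigDecimal("11.2");
//
// while ZERO, ONE and TEN map directly to the java.math.BigDecimal constants.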
java.lang.reflect.Method -import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{DOUBLE_TYPE_INFO, FLOAT_TYPE_INFO} +import org.apache.flink.api.common.typeinfo.BasicTypeInfo. + {DOUBLE_TYPE_INFO, FLOAT_TYPE_INFO,BIG_DEC_TYPE_INFO} import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedExpression} /** @@ -33,7 +34,7 @@ class FloorCeilCallGen(method: Method) extends MultiTypeMethodCallGen(method) { operands: Seq[GeneratedExpression]) : GeneratedExpression = { operands.head.resultType match { - case FLOAT_TYPE_INFO | DOUBLE_TYPE_INFO => + case FLOAT_TYPE_INFO | DOUBLE_TYPE_INFO | BIG_DEC_TYPE_INFO => super.generate(codeGenerator, operands) case _ => operands.head // no floor/ceil necessary diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala index 9046a77582c..1462d9cc291 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala @@ -142,16 +142,31 @@ object ScalarFunctions { Seq(DOUBLE_TYPE_INFO), new MultiTypeMethodCallGen(BuiltInMethods.ABS)) + addSqlFunction( + ABS, + Seq(BIG_DEC_TYPE_INFO), + new MultiTypeMethodCallGen(BuiltInMethods.ABS_DEC)) + addSqlFunction( FLOOR, Seq(DOUBLE_TYPE_INFO), new FloorCeilCallGen(BuiltInMethod.FLOOR.method)) + addSqlFunction( + FLOOR, + Seq(BIG_DEC_TYPE_INFO), + new FloorCeilCallGen(BuiltInMethod.FLOOR.method)) + addSqlFunction( CEIL, Seq(DOUBLE_TYPE_INFO), new FloorCeilCallGen(BuiltInMethod.CEIL.method)) + addSqlFunction( + CEIL, + Seq(BIG_DEC_TYPE_INFO), + new FloorCeilCallGen(BuiltInMethod.CEIL.method)) + // ---------------------------------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala index 182b8432780..189096d8866 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala @@ -18,12 +18,23 @@ package org.apache.flink.api.table.codegen.calls import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, NumericTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation} import org.apache.flink.api.table.codegen.CodeGenUtils._ import org.apache.flink.api.table.codegen.{CodeGenException, GeneratedExpression} +import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isBoolean, isComparable, isDecimal, isNumeric} object ScalarOperators { + def generateStringConcatOperator( + nullCheck: Boolean, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + generateOperatorIfNotNull(nullCheck, STRING_TYPE_INFO, left, right) { + (leftTerm, rightTerm) => s"$leftTerm + $rightTerm" + } + } + def generateArithmeticOperator( operator: String, nullCheck: Boolean, @@ -31,40 +42,17 @@ object ScalarOperators { left: GeneratedExpression, right: GeneratedExpression) : GeneratedExpression = { - // String arithmetic // TODO rework - if (isString(left)) { - 
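// Editor's sketch (assumed table/column names): the BIG_DEC_TYPE_INFO
// registrations above make ABS, FLOOR and CEIL work on decimal operands, e.g.:
//
//   val result = tableEnv.sql(
//     "SELECT ABS(amount), FLOOR(amount), CEIL(amount) FROM revenue")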
generateOperatorIfNotNull(nullCheck, resultType, left, right) { - (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" - } - } - // Numeric arithmetic - else if (isNumeric(left) && isNumeric(right)) { - val leftType = left.resultType.asInstanceOf[NumericTypeInfo[_]] - val rightType = right.resultType.asInstanceOf[NumericTypeInfo[_]] - val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType) + val leftCasting = numericCasting(left.resultType, resultType) + val rightCasting = numericCasting(right.resultType, resultType) + val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType) - generateOperatorIfNotNull(nullCheck, resultType, left, right) { + generateOperatorIfNotNull(nullCheck, resultType, left, right) { (leftTerm, rightTerm) => - // no casting required - if (leftType == resultType && rightType == resultType) { - s"$leftTerm $operator $rightTerm" - } - // left needs casting - else if (leftType != resultType && rightType == resultType) { - s"(($resultTypeTerm) $leftTerm) $operator $rightTerm" - } - // right needs casting - else if (leftType == resultType && rightType != resultType) { - s"$leftTerm $operator (($resultTypeTerm) $rightTerm)" - } - // both sides need casting - else { - s"(($resultTypeTerm) $leftTerm) $operator (($resultTypeTerm) $rightTerm)" + if (isDecimal(resultType)) { + s"${leftCasting(leftTerm)}.${arithOpToDecMethod(operator)}(${rightCasting(rightTerm)})" + } else { + s"($resultTypeTerm) (${leftCasting(leftTerm)} $operator ${rightCasting(rightTerm)})" } - } - } - else { - throw new CodeGenException("Unsupported arithmetic operation.") } } @@ -75,7 +63,16 @@ object ScalarOperators { operand: GeneratedExpression) : GeneratedExpression = { generateUnaryOperatorIfNotNull(nullCheck, resultType, operand) { - (operandTerm) => s"$operator($operandTerm)" + (operandTerm) => + if (isDecimal(operand.resultType) && operator == "-") { + s"$operandTerm.negate()" + } else if (isDecimal(operand.resultType) && operator == "+") { + s"$operandTerm" + } else if (isNumeric(operand.resultType)) { + s"$operator($operandTerm)" + } else { + throw new CodeGenException("Unsupported unary operator.") + } } } @@ -84,15 +81,27 @@ object ScalarOperators { left: GeneratedExpression, right: GeneratedExpression) : GeneratedExpression = { - generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) { - if (isReference(left)) { - (leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)" - } - else if (isReference(right)) { - (leftTerm, rightTerm) => s"$rightTerm.equals($leftTerm)" - } - else { - (leftTerm, rightTerm) => s"$leftTerm == $rightTerm" + // numeric types + if (isNumeric(left.resultType) && isNumeric(right.resultType)) { + generateComparison("==", nullCheck, left, right) + } + // comparable types of same type + else if (isComparable(left.resultType) && left.resultType == right.resultType) { + generateComparison("==", nullCheck, left, right) + } + // non comparable types + else { + generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) { + if (isReference(left)) { + (leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)" + } + else if (isReference(right)) { + (leftTerm, rightTerm) => s"$rightTerm.equals($leftTerm)" + } + else { + throw new CodeGenException(s"Incomparable types: ${left.resultType} and " + + s"${right.resultType}") + } } } } @@ -102,19 +111,34 @@ object ScalarOperators { left: GeneratedExpression, right: GeneratedExpression) : GeneratedExpression = { - generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) { - if 
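// Editor's note on generateArithmeticOperator above: for a DECIMAL result type
// the infix operator is replaced by the matching BigDecimal method (see
// arithOpToDecMethod further below), so "amount + tax" is emitted roughly as
//
//   amount.add(tax)
//
// with numericCasting first lifting any non-decimal operand via
// java.math.BigDecimal.valueOf(...).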
(isReference(left)) { - (leftTerm, rightTerm) => s"!($leftTerm.equals($rightTerm))" - } - else if (isReference(right)) { - (leftTerm, rightTerm) => s"!($rightTerm.equals($leftTerm))" - } - else { - (leftTerm, rightTerm) => s"$leftTerm != $rightTerm" + // numeric types + if (isNumeric(left.resultType) && isNumeric(right.resultType)) { + generateComparison("!=", nullCheck, left, right) + } + // comparable types + else if (isComparable(left.resultType) && left.resultType == right.resultType) { + generateComparison("!=", nullCheck, left, right) + } + // non comparable types + else { + generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) { + if (isReference(left)) { + (leftTerm, rightTerm) => s"!($leftTerm.equals($rightTerm))" + } + else if (isReference(right)) { + (leftTerm, rightTerm) => s"!($rightTerm.equals($leftTerm))" + } + else { + throw new CodeGenException(s"Incomparable types: ${left.resultType} and " + + s"${right.resultType}") + } } } } + /** + * Generates comparison code for numeric types and comparable types of same type. + */ def generateComparison( operator: String, nullCheck: Boolean, @@ -122,14 +146,38 @@ object ScalarOperators { right: GeneratedExpression) : GeneratedExpression = { generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) { - if (isString(left) && isString(right)) { - (leftTerm, rightTerm) => s"$leftTerm.compareTo($rightTerm) $operator 0" + // left is decimal or both sides are decimal + if (isDecimal(left.resultType) && isNumeric(right.resultType)) { + (leftTerm, rightTerm) => { + val operandCasting = numericCasting(right.resultType, left.resultType) + s"$leftTerm.compareTo(${operandCasting(rightTerm)}) $operator 0" + } + } + // right is decimal + else if (isNumeric(left.resultType) && isDecimal(right.resultType)) { + (leftTerm, rightTerm) => { + val operandCasting = numericCasting(left.resultType, right.resultType) + s"${operandCasting(leftTerm)}.compareTo($rightTerm) $operator 0" + } } - else if (isNumeric(left) && isNumeric(right)) { + // both sides are numeric + else if (isNumeric(left.resultType) && isNumeric(right.resultType)) { (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" } + // both sides are boolean + else if (isBoolean(left.resultType) && left.resultType == right.resultType) { + operator match { + case "==" | "!=" => (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" + case _ => throw new CodeGenException(s"Unsupported boolean comparison '$operator'.") + } + } + // both sides are same comparable type + else if (isComparable(left.resultType) && left.resultType == right.resultType) { + (leftTerm, rightTerm) => s"$leftTerm.compareTo($rightTerm) $operator 0" + } else { - throw new CodeGenException("Comparison is only supported for Strings and numeric types.") + throw new CodeGenException(s"Incomparable types: ${left.resultType} and " + + s"${right.resultType}") } } } @@ -147,7 +195,7 @@ object ScalarOperators { |boolean $nullTerm = false; |""".stripMargin } - else if (!nullCheck && isReference(operand.resultType)) { + else if (!nullCheck && isReference(operand)) { s""" |${operand.code} |boolean $resultTerm = ${operand.resultTerm} == null; @@ -177,7 +225,7 @@ object ScalarOperators { |boolean $nullTerm = false; |""".stripMargin } - else if (!nullCheck && isReference(operand.resultType)) { + else if (!nullCheck && isReference(operand)) { s""" |${operand.code} |boolean $resultTerm = ${operand.resultTerm} != null; @@ -326,63 +374,72 @@ object ScalarOperators { nullCheck: Boolean, operand: GeneratedExpression, 
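// Editor's sketch of generateComparison above for mixed operands: comparing a
// DECIMAL field with an integer literal compiles to a compareTo call on the
// casted operand, e.g. 'amount > 12 becomes roughly
//
//   amount.compareTo(java.math.BigDecimal.valueOf((long) 12)) > 0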
targetType: TypeInformation[_]) - : GeneratedExpression = { - targetType match { - // identity casting - case operand.resultType => - generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { - (operandTerm) => s"$operandTerm" - } + : GeneratedExpression = (operand.resultType, targetType) match { + // identity casting + case (fromTp, toTp) if fromTp == toTp => + operand + + // * -> String + case (_, STRING_TYPE_INFO) => + generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { + (operandTerm) => s""" "" + $operandTerm""" + } - // * -> String - case STRING_TYPE_INFO => - generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { - (operandTerm) => s""" "" + $operandTerm""" - } + // * -> Character + case (_, CHAR_TYPE_INFO) => + throw new CodeGenException("Character type not supported.") - // * -> Date - case DATE_TYPE_INFO => - throw new CodeGenException("Date type not supported yet.") + // String -> NUMERIC TYPE (not Character), Boolean + case (STRING_TYPE_INFO, _: NumericTypeInfo[_]) + | (STRING_TYPE_INFO, BOOLEAN_TYPE_INFO) => + val wrapperClass = targetType.getTypeClass.getCanonicalName + generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { + (operandTerm) => s"$wrapperClass.valueOf($operandTerm)" + } - // * -> Void - case VOID_TYPE_INFO => - throw new CodeGenException("Void type not supported.") + // String -> BigDecimal + case (STRING_TYPE_INFO, BIG_DEC_TYPE_INFO) => + val wrapperClass = targetType.getTypeClass.getCanonicalName + generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { + (operandTerm) => s"new $wrapperClass($operandTerm)" + } - // * -> Character - case CHAR_TYPE_INFO => - throw new CodeGenException("Character type not supported.") + // Boolean -> NUMERIC TYPE + case (BOOLEAN_TYPE_INFO, nti: NumericTypeInfo[_]) => + val targetTypeTerm = primitiveTypeTermForTypeInfo(nti) + generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { + (operandTerm) => s"($targetTypeTerm) ($operandTerm ? 1 : 0)" + } - // NUMERIC TYPE -> Boolean - case BOOLEAN_TYPE_INFO if isNumeric(operand) => - generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { - (operandTerm) => s"$operandTerm != 0" - } + // Boolean -> BigDecimal + case (BOOLEAN_TYPE_INFO, BIG_DEC_TYPE_INFO) => + generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { + (operandTerm) => s"$operandTerm ? 
java.math.BigDecimal.ONE : java.math.BigDecimal.ZERO" + } - // String -> BASIC TYPE (not String, Date, Void, Character) - case ti: BasicTypeInfo[_] if isString(operand) => - val wrapperClass = targetType.getTypeClass.getCanonicalName - generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { - (operandTerm) => s"$wrapperClass.valueOf($operandTerm)" - } + // NUMERIC TYPE -> Boolean + case (_: NumericTypeInfo[_], BOOLEAN_TYPE_INFO) => + generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { + (operandTerm) => s"$operandTerm != 0" + } - // NUMERIC TYPE -> NUMERIC TYPE - case nti: NumericTypeInfo[_] if isNumeric(operand) => - val targetTypeTerm = primitiveTypeTermForTypeInfo(nti) - generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { - (operandTerm) => s"($targetTypeTerm) $operandTerm" - } + // BigDecimal -> Boolean + case (BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO) => + generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { + (operandTerm) => s"$operandTerm.compareTo(java.math.BigDecimal.ZERO) != 0" + } - // Boolean -> NUMERIC TYPE - case nti: NumericTypeInfo[_] if isBoolean(operand) => - val targetTypeTerm = primitiveTypeTermForTypeInfo(nti) - generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { - (operandTerm) => s"($targetTypeTerm) ($operandTerm ? 1 : 0)" - } + // NUMERIC TYPE, BigDecimal -> NUMERIC TYPE, BigDecimal + case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) + | (BIG_DEC_TYPE_INFO, _: NumericTypeInfo[_]) + | (_: NumericTypeInfo[_], BIG_DEC_TYPE_INFO) => + val operandCasting = numericCasting(operand.resultType, targetType) + generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { + (operandTerm) => s"${operandCasting(operandTerm)}" + } - case _ => - throw new CodeGenException(s"Unsupported cast from '${operand.resultType}'" + - s"to '$targetType'.") - } + case (from, to) => + throw new CodeGenException(s"Unsupported cast from '$from' to '$to'.") } def generateIfElse( @@ -519,4 +576,51 @@ object ScalarOperators { GeneratedExpression(resultTerm, nullTerm, resultCode, resultType) } + private def arithOpToDecMethod(operator: String): String = operator match { + case "+" => "add" + case "-" => "subtract" + case "*" => "multiply" + case "/" => "divide" + case "%" => "remainder" + case _ => throw new CodeGenException("Unsupported decimal arithmetic operator.") + } + + private def numericCasting( + operandType: TypeInformation[_], + resultType: TypeInformation[_]) + : (String) => String = { + + def decToPrimMethod(targetType: TypeInformation[_]): String = targetType match { + case BYTE_TYPE_INFO => "byteValueExact" + case SHORT_TYPE_INFO => "shortValueExact" + case INT_TYPE_INFO => "intValueExact" + case LONG_TYPE_INFO => "longValueExact" + case FLOAT_TYPE_INFO => "floatValue" + case DOUBLE_TYPE_INFO => "doubleValue" + case _ => throw new CodeGenException("Unsupported decimal casting type.") + } + + val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType) + // no casting necessary + if (operandType == resultType) { + (operandTerm) => s"$operandTerm" + } + // result type is decimal but numeric operand is not + else if (isDecimal(resultType) && !isDecimal(operandType) && isNumeric(operandType)) { + (operandTerm) => + s"java.math.BigDecimal.valueOf((${superPrimitive(operandType)}) $operandTerm)" + } + // numeric result type is not decimal but operand is + else if (isNumeric(resultType) && !isDecimal(resultType) && isDecimal(operandType) ) { + (operandTerm) => s"$operandTerm.${decToPrimMethod(resultType)}()" + } + // result type 
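// Editor's note on decToPrimMethod above: casts from DECIMAL to integral types
// use the *ValueExact variants and fail fast on overflow, while casts to
// FLOAT/DOUBLE may silently lose precision or range, e.g.:
//
//   new java.math.BigDecimal("12147483647").intValueExact()  // ArithmeticException
//   new java.math.BigDecimal("12147483647").doubleValue()    // 1.2147483647E10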
and operand type are numeric but not decimal + else if (isNumeric(operandType) && isNumeric(resultType) + && !isDecimal(operandType) && !isDecimal(resultType)) { + (operandTerm) => s"(($resultTypeTerm) $operandTerm)" + } + else { + throw new CodeGenException(s"Unsupported casting from $operandType to $resultType.") + } + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala index db3d18764a2..2d3f611fb03 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala @@ -71,23 +71,24 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { "DOUBLE" ^^ { ti => BasicTypeInfo.DOUBLE_TYPE_INFO } | ("BOOL" | "BOOLEAN" ) ^^ { ti => BasicTypeInfo.BOOLEAN_TYPE_INFO } | "STRING" ^^ { ti => BasicTypeInfo.STRING_TYPE_INFO } | - "DATE" ^^ { ti => BasicTypeInfo.DATE_TYPE_INFO } + "DATE" ^^ { ti => BasicTypeInfo.DATE_TYPE_INFO } | + "DECIMAL" ^^ { ti => BasicTypeInfo.BIG_DEC_TYPE_INFO } // Literals lazy val numberLiteral: PackratParser[Expression] = - ((wholeNumber <~ ("L" | "l")) | floatingPointNumber | decimalNumber | wholeNumber) ^^ { - str => - if (str.endsWith("L") || str.endsWith("l")) { - Literal(str.toLong) - } else if (str.matches("""-?\d+""")) { - Literal(str.toInt) - } else if (str.endsWith("f") | str.endsWith("F")) { - Literal(str.toFloat) - } else { - Literal(str.toDouble) - } - } + (wholeNumber <~ ("l" | "L")) ^^ { n => Literal(n.toLong) } | + (decimalNumber <~ ("p" | "P")) ^^ { n => Literal(BigDecimal(n)) } | + (floatingPointNumber | decimalNumber) ^^ { + n => + if (n.matches("""-?\d+""")) { + Literal(n.toInt) + } else if (n.endsWith("f") || n.endsWith("F")) { + Literal(n.toFloat) + } else { + Literal(n.toDouble) + } + } lazy val singleQuoteStringLiteral: Parser[Expression] = ("'" + """([^'\p{Cntrl}\\]|\\[\\'"bfnrt]|\\u[a-fA-F0-9]{4})*""" + "'").r ^^ { @@ -261,7 +262,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val unaryMinus: PackratParser[Expression] = "-" ~> composite ^^ { e => UnaryMinus(e) } - lazy val unary = unaryNot | unaryMinus | composite + lazy val unary = composite | unaryNot | unaryMinus // arithmetic diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/arithmetic.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/arithmetic.scala index 0ce4685eb54..4fa5a7138ce 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/arithmetic.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/arithmetic.scala @@ -17,18 +17,17 @@ */ package org.apache.flink.api.table.expressions -import scala.collection.JavaConversions._ - import org.apache.calcite.rex.RexNode -import org.apache.calcite.sql.`type`.SqlTypeName import org.apache.calcite.sql.SqlOperator import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.tools.RelBuilder - -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, NumericTypeInfo, TypeInformation} -import org.apache.flink.api.table.typeutils.{TypeCheckUtils, TypeCoercion, TypeConverter} +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import 
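// Editor's sketch of the revised numberLiteral rule above (results shown are
// assumed, not taken from the patch's tests):
//
//   ExpressionParser.parseExpression("42")      // Literal(42)        Int
//   ExpressionParser.parseExpression("42L")     // Literal(42L)       Long
//   ExpressionParser.parseExpression("11.2f")   // Literal(11.2f)     Float
//   ExpressionParser.parseExpression("11.2")    // Literal(11.2)      Double
//   ExpressionParser.parseExpression("11.2p")   // Literal(BigDecimal("11.2"))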
org.apache.flink.api.table.typeutils.TypeCheckUtils.{isNumeric, isString} +import org.apache.flink.api.table.typeutils.{TypeCheckUtils, TypeCoercion} import org.apache.flink.api.table.validate._ +import scala.collection.JavaConversions._ + abstract class BinaryArithmetic extends BinaryExpression { def sqlOperator: SqlOperator @@ -45,9 +44,8 @@ abstract class BinaryArithmetic extends BinaryExpression { // TODO: tighten this rule once we implemented type coercion rules during validation override def validateInput(): ExprValidationResult = { - if (!left.resultType.isInstanceOf[NumericTypeInfo[_]] || - !right.resultType.isInstanceOf[NumericTypeInfo[_]]) { - ValidationFailure(s"$this requires both operands Numeric, get" + + if (!isNumeric(left.resultType) || !isNumeric(right.resultType)) { + ValidationFailure(s"$this requires both operands Numeric, got " + s"${left.resultType} and ${right.resultType}") } else { ValidationSuccess @@ -61,28 +59,24 @@ case class Plus(left: Expression, right: Expression) extends BinaryArithmetic { val sqlOperator = SqlStdOperatorTable.PLUS override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { - val l = left.toRexNode - val r = right.toRexNode - if(SqlTypeName.STRING_TYPES.contains(l.getType.getSqlTypeName)) { - val cast: RexNode = relBuilder.cast(r, - TypeConverter.typeInfoToSqlType(BasicTypeInfo.STRING_TYPE_INFO)) - relBuilder.call(SqlStdOperatorTable.PLUS, l, cast) - } else if(SqlTypeName.STRING_TYPES.contains(r.getType.getSqlTypeName)) { - val cast: RexNode = relBuilder.cast(l, - TypeConverter.typeInfoToSqlType(BasicTypeInfo.STRING_TYPE_INFO)) - relBuilder.call(SqlStdOperatorTable.PLUS, cast, r) + if(isString(left.resultType)) { + val castedRight = Cast(right, BasicTypeInfo.STRING_TYPE_INFO) + relBuilder.call(SqlStdOperatorTable.PLUS, left.toRexNode, castedRight.toRexNode) + } else if(isString(right.resultType)) { + val castedLeft = Cast(left, BasicTypeInfo.STRING_TYPE_INFO) + relBuilder.call(SqlStdOperatorTable.PLUS, castedLeft.toRexNode, right.toRexNode) } else { - relBuilder.call(SqlStdOperatorTable.PLUS, l, r) + val castedLeft = Cast(left, resultType) + val castedRight = Cast(right, resultType) + relBuilder.call(SqlStdOperatorTable.PLUS, castedLeft.toRexNode, castedRight.toRexNode) } } // TODO: tighten this rule once we implemented type coercion rules during validation override def validateInput(): ExprValidationResult = { - if (left.resultType == BasicTypeInfo.STRING_TYPE_INFO || - right.resultType == BasicTypeInfo.STRING_TYPE_INFO) { + if (isString(left.resultType) || isString(right.resultType)) { ValidationSuccess - } else if (!left.resultType.isInstanceOf[NumericTypeInfo[_]] || - !right.resultType.isInstanceOf[NumericTypeInfo[_]]) { + } else if (!isNumeric(left.resultType) || !isNumeric(right.resultType)) { ValidationFailure(s"$this requires Numeric or String input," + s" get ${left.resultType} and ${right.resultType}") } else { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala index 63caeaa43d3..4d67f8e6550 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala @@ -17,17 +17,16 @@ */ package org.apache.flink.api.table.expressions -import scala.collection.JavaConversions._ - import org.apache.calcite.rex.RexNode import 
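// Editor's note on Plus.toRexNode above: if either operand is a String the
// other side is cast to STRING and PLUS becomes a concatenation; otherwise
// both sides are cast to the widened result type, e.g. (sketch):
//
//   'name + 12      // 12 is cast to STRING, result is a concatenation
//   'count + 11.2p  // 'count is widened to DECIMAL, result is decimal addition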
org.apache.calcite.sql.SqlOperator
 import org.apache.calcite.sql.fun.SqlStdOperatorTable
 import org.apache.calcite.tools.RelBuilder
-
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.NumericTypeInfo
+import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isComparable, isNumeric}
 import org.apache.flink.api.table.validate._
 
+import scala.collection.JavaConversions._
+
 abstract class BinaryComparison extends BinaryExpression {
   def sqlOperator: SqlOperator
 
@@ -39,11 +38,12 @@ abstract class BinaryComparison extends BinaryExpression {
 
   // TODO: tighten this rule once we implemented type coercion rules during validation
   override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match {
-    case (STRING_TYPE_INFO, STRING_TYPE_INFO) => ValidationSuccess
-    case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => ValidationSuccess
+    case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
+    case (lType, rType) if isComparable(lType) && lType == rType => ValidationSuccess
     case (lType, rType) =>
       ValidationFailure(
-        s"Comparison is only supported for Strings and numeric types, get $lType and $rType")
+        s"Comparison is only supported for numeric types and comparable types of same type, " +
+          s"got $lType and $rType")
   }
 }
 
@@ -53,13 +53,11 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
   val sqlOperator: SqlOperator = SqlStdOperatorTable.EQUALS
 
   override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match {
-    case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => ValidationSuccess
+    case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
+    // TODO widen this rule once we support custom objects as types (FLINK-3916)
+    case (lType, rType) if lType == rType => ValidationSuccess
     case (lType, rType) =>
-      if (lType != rType) {
-        ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
-      } else {
-        ValidationSuccess
-      }
+      ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
   }
 }
 
@@ -69,13 +67,11 @@ case class NotEqualTo(left: Expression, right: Expression) extends BinaryCompari
   val sqlOperator: SqlOperator = SqlStdOperatorTable.NOT_EQUALS
 
   override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match {
-    case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => ValidationSuccess
+    case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
+    // TODO widen this rule once we support custom objects as types (FLINK-3916)
+    case (lType, rType) if lType == rType => ValidationSuccess
     case (lType, rType) =>
-      if (lType != rType) {
-        ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
-      } else {
-        ValidationSuccess
-      }
+      ValidationFailure(s"Inequality predicate on incompatible types: $lType and $rType")
   }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/literals.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/literals.scala
index 9caec26bd55..1574178d90d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/literals.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/literals.scala
@@ -20,6 +20,7 @@ package org.apache.flink.api.table.expressions
 import java.util.Date
 
 import org.apache.calcite.rex.RexNode
+import org.apache.calcite.sql.`type`.SqlTypeName
 import org.apache.calcite.tools.RelBuilder
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.table.typeutils.TypeConverter
@@ -35,6 +36,9 @@ object Literal {
     case str: String => Literal(str, BasicTypeInfo.STRING_TYPE_INFO)
     case bool: Boolean => Literal(bool, BasicTypeInfo.BOOLEAN_TYPE_INFO)
     case date: Date => Literal(date, BasicTypeInfo.DATE_TYPE_INFO)
+    case javaDec: java.math.BigDecimal => Literal(javaDec, BasicTypeInfo.BIG_DEC_TYPE_INFO)
+    case scalaDec: scala.math.BigDecimal =>
+      Literal(scalaDec.bigDecimal, BasicTypeInfo.BIG_DEC_TYPE_INFO)
   }
 }
 
@@ -42,7 +46,13 @@ case class Literal(value: Any, resultType: TypeInformation[_]) extends LeafExpre
   override def toString = s"$value"
 
   override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
-    relBuilder.literal(value)
+    resultType match {
+      case BasicTypeInfo.BIG_DEC_TYPE_INFO =>
+        val bigDecValue = value.asInstanceOf[java.math.BigDecimal]
+        val decType = relBuilder.getTypeFactory.createSqlType(SqlTypeName.DECIMAL)
+        relBuilder.getRexBuilder.makeExactLiteral(bigDecValue, decType)
+      case _ => relBuilder.literal(value)
+    }
   }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetRel.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetRel.scala
index 946dfc0f6c6..08e0c418e2c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetRel.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetRel.scala
@@ -61,6 +61,7 @@ trait DataSetRel extends RelNode with FlinkRel {
         case SqlTypeName.DOUBLE => s + 8
         case SqlTypeName.VARCHAR => s + 12
         case SqlTypeName.CHAR => s + 1
+        case SqlTypeName.DECIMAL => s + 12
         case _ => throw new TableException("Unsupported data type encountered")
       }
     }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
index 82364ebf640..44a67b65349 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
@@ -155,6 +155,8 @@ object AggregateUtil {
           new FloatSumAggregate
         case DOUBLE =>
           new DoubleSumAggregate
+        case DECIMAL =>
+          new DecimalSumAggregate
         case sqlType: SqlTypeName =>
           throw new TableException("Sum aggregate does not support type:" + sqlType)
       }
@@ -173,6 +175,8 @@
           new FloatAvgAggregate
         case DOUBLE =>
           new DoubleAvgAggregate
+        case DECIMAL =>
+          new DecimalAvgAggregate
         case sqlType: SqlTypeName =>
           throw new TableException("Avg aggregate does not support type:" + sqlType)
       }
@@ -192,6 +196,8 @@
           new FloatMinAggregate
         case DOUBLE =>
           new DoubleMinAggregate
+        case DECIMAL =>
+          new DecimalMinAggregate
         case BOOLEAN =>
           new BooleanMinAggregate
         case sqlType: SqlTypeName =>
@@ -211,6 +217,8 @@
           new FloatMaxAggregate
         case DOUBLE =>
           new DoubleMaxAggregate
+        case DECIMAL =>
+          new DecimalMaxAggregate
         case BOOLEAN =>
           new BooleanMaxAggregate
         case sqlType: SqlTypeName =>
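// Editor's sketch (assumed schema): with the DECIMAL cases above, grouped
// aggregates over BigDecimal columns now dispatch to the new decimal
// implementations, e.g.:
//
//   tableEnv.sql(
//     "SELECT SUM(amount), AVG(amount), MIN(amount), MAX(amount) FROM revenue")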
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala
index e7246481a67..ce5bc818f1a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala
@@ -18,8 +18,9 @@ package org.apache.flink.api.table.runtime.aggregate
 
 import com.google.common.math.LongMath
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.table.Row
+import java.math.BigDecimal
 import java.math.BigInteger
 
 abstract class AvgAggregate[T] extends Aggregate[T] {
@@ -251,3 +252,45 @@ class DoubleAvgAggregate extends FloatingAvgAggregate[Double] {
     }
   }
 }
+
+class DecimalAvgAggregate extends AvgAggregate[BigDecimal] {
+
+  override def intermediateDataType = Array(
+    BasicTypeInfo.BIG_DEC_TYPE_INFO,
+    BasicTypeInfo.LONG_TYPE_INFO)
+
+  override def initiate(partial: Row): Unit = {
+    partial.setField(partialSumIndex, BigDecimal.ZERO)
+    partial.setField(partialCountIndex, 0L)
+  }
+
+  override def prepare(value: Any, partial: Row): Unit = {
+    if (value == null) {
+      initiate(partial)
+    } else {
+      val input = value.asInstanceOf[BigDecimal]
+      partial.setField(partialSumIndex, input)
+      partial.setField(partialCountIndex, 1L)
+    }
+  }
+
+  override def merge(partial: Row, buffer: Row): Unit = {
+    val partialSum = partial.productElement(partialSumIndex).asInstanceOf[BigDecimal]
+    val partialCount = partial.productElement(partialCountIndex).asInstanceOf[Long]
+    val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[BigDecimal]
+    val bufferCount = buffer.productElement(partialCountIndex).asInstanceOf[Long]
+    buffer.setField(partialSumIndex, partialSum.add(bufferSum))
+    buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
+  }
+
+  override def evaluate(buffer: Row): BigDecimal = {
+    val bufferCount = buffer.productElement(partialCountIndex).asInstanceOf[Long]
+    if (bufferCount != 0) {
+      val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[BigDecimal]
+      // a bounded context avoids ArithmeticException on non-terminating quotients (e.g. 1/3)
+      bufferSum.divide(BigDecimal.valueOf(bufferCount), java.math.MathContext.DECIMAL128)
+    } else {
+      null.asInstanceOf[BigDecimal]
+    }
+  }
+
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala
index b9b86d18f90..9267527e38e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala
@@ -17,6 +17,8 @@
  */
 package org.apache.flink.api.table.runtime.aggregate
 
+import java.math.BigDecimal
+
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.table.Row
 
@@ -125,3 +127,45 @@ class BooleanMaxAggregate extends MaxAggregate[Boolean] {
 
   override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)
 }
+
+class DecimalMaxAggregate extends Aggregate[BigDecimal] {
+
+  protected var maxIndex: Int = _
+
+  override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
+
+  override def initiate(intermediate: Row): Unit = {
+    intermediate.setField(maxIndex, null)
+  }
+
+  override def prepare(value: Any, partial: Row): Unit = {
+    if (value == null) {
+      initiate(partial)
+    } else {
+      partial.setField(maxIndex, value)
+    }
+  }
+
+  override def merge(partial: Row, buffer: Row): Unit = {
+    val partialValue = partial.productElement(maxIndex).asInstanceOf[BigDecimal]
+    if (partialValue != null) {
+      val bufferValue = buffer.productElement(maxIndex).asInstanceOf[BigDecimal]
+      if (bufferValue != null) {
+        val max = if (partialValue.compareTo(bufferValue) > 0) partialValue else bufferValue
+        buffer.setField(maxIndex, max)
+      } else {
+        buffer.setField(maxIndex, partialValue)
+      }
+    }
+  }
+
+  override def evaluate(buffer: Row): BigDecimal = {
+    buffer.productElement(maxIndex).asInstanceOf[BigDecimal]
+  }
+
+  override def supportPartial: Boolean = true
+
+  override def setAggOffsetInRow(aggOffset: Int): Unit = {
+    maxIndex = aggOffset
+  }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala
index 5d656f47b09..7e2ebf40104 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala
@@ -17,6 +17,7 @@
  */
 package org.apache.flink.api.table.runtime.aggregate
 
+import java.math.BigDecimal
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.table.Row
 
@@ -125,3 +126,45 @@ class BooleanMinAggregate extends MinAggregate[Boolean] {
 
   override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)
 }
+
+class DecimalMinAggregate extends Aggregate[BigDecimal] {
+
+  protected var minIndex: Int = _
+
+  override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
+
+  override def initiate(intermediate: Row): Unit = {
+    intermediate.setField(minIndex, null)
+  }
+
+  override def prepare(value: Any, partial: Row): Unit = {
+    if (value == null) {
+      initiate(partial)
+    } else {
+      partial.setField(minIndex, value)
+    }
+  }
+
+  override def merge(partial: Row, buffer: Row): Unit = {
+    val partialValue = partial.productElement(minIndex).asInstanceOf[BigDecimal]
+    if (partialValue != null) {
+      val bufferValue = buffer.productElement(minIndex).asInstanceOf[BigDecimal]
+      if (bufferValue != null) {
+        val min = if (partialValue.compareTo(bufferValue) < 0) partialValue else bufferValue
+        buffer.setField(minIndex, min)
+      } else {
+        buffer.setField(minIndex, partialValue)
+      }
+    }
+  }
+
+  override def evaluate(buffer: Row): BigDecimal = {
+    buffer.productElement(minIndex).asInstanceOf[BigDecimal]
+  }
+
+  override def supportPartial: Boolean = true
+
+  override def setAggOffsetInRow(aggOffset: Int): Unit = {
+    minIndex = aggOffset
+  }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
index 6db6632c757..7ff23404c91 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
@@ -17,6 +17,7 @@
  */
 package org.apache.flink.api.table.runtime.aggregate
 
+import java.math.BigDecimal
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.table.Row
 
@@ -85,3 +86,45 @@ class FloatSumAggregate extends SumAggregate[Float] {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
index 6db6632c757..7ff23404c91 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
@@ -17,6 +17,7 @@
  */
 package org.apache.flink.api.table.runtime.aggregate
 
+import java.math.BigDecimal
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.table.Row
 
@@ -85,3 +86,45 @@ class FloatSumAggregate extends SumAggregate[Float] {
 class DoubleSumAggregate extends SumAggregate[Double] {
   override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO)
 }
+
+class DecimalSumAggregate extends Aggregate[BigDecimal] {
+
+  protected var sumIndex: Int = _
+
+  override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
+
+  override def initiate(partial: Row): Unit = {
+    partial.setField(sumIndex, null)
+  }
+
+  override def merge(partial1: Row, buffer: Row): Unit = {
+    val partialValue = partial1.productElement(sumIndex).asInstanceOf[BigDecimal]
+    if (partialValue != null) {
+      val bufferValue = buffer.productElement(sumIndex).asInstanceOf[BigDecimal]
+      if (bufferValue != null) {
+        buffer.setField(sumIndex, partialValue.add(bufferValue))
+      } else {
+        buffer.setField(sumIndex, partialValue)
+      }
+    }
+  }
+
+  override def evaluate(buffer: Row): BigDecimal = {
+    buffer.productElement(sumIndex).asInstanceOf[BigDecimal]
+  }
+
+  override def prepare(value: Any, partial: Row): Unit = {
+    if (value == null) {
+      initiate(partial)
+    } else {
+      val input = value.asInstanceOf[BigDecimal]
+      partial.setField(sumIndex, input)
+    }
+  }
+
+  override def supportPartial: Boolean = true
+
+  override def setAggOffsetInRow(aggOffset: Int): Unit = {
+    sumIndex = aggOffset
+  }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala
index 1da1d2cf0a0..45ee76409bb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala
@@ -17,17 +17,37 @@
  */
 package org.apache.flink.api.table.typeutils
 
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO, STRING_TYPE_INFO}
 import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation}
 import org.apache.flink.api.table.validate._
 
 object TypeCheckUtils {
 
-  def assertNumericExpr(dataType: TypeInformation[_], caller: String): ExprValidationResult = {
-    if (dataType.isInstanceOf[NumericTypeInfo[_]]) {
+  def isNumeric(dataType: TypeInformation[_]): Boolean = dataType match {
+    case _: NumericTypeInfo[_] => true
+    case BIG_DEC_TYPE_INFO => true
+    case _ => false
+  }
+
+  def isString(dataType: TypeInformation[_]): Boolean = dataType == STRING_TYPE_INFO
+
+  def isBoolean(dataType: TypeInformation[_]): Boolean = dataType == BOOLEAN_TYPE_INFO
+
+  def isDecimal(dataType: TypeInformation[_]): Boolean = dataType == BIG_DEC_TYPE_INFO
+
+  def isComparable(dataType: TypeInformation[_]): Boolean =
+    classOf[Comparable[_]].isAssignableFrom(dataType.getTypeClass)
+
+  def assertNumericExpr(
+      dataType: TypeInformation[_],
+      caller: String)
+    : ExprValidationResult = dataType match {
+    case _: NumericTypeInfo[_] =>
       ValidationSuccess
-    } else {
+    case BIG_DEC_TYPE_INFO =>
+      ValidationSuccess
+    case _ =>
       ValidationFailure(s"$caller requires numeric types, get $dataType here")
-    }
   }
 
   def assertOrderableExpr(dataType: TypeInformation[_], caller: String): ExprValidationResult = {
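BIG_DEC_TYPE_INFO is not a NumericTypeInfo, which is why the predicates above need a dedicated case for it. A small sketch (not part of the patch; it assumes the patched TypeCheckUtils above and the TypeCoercion changes in the next diff) of the intended behaviour:

    import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
    import org.apache.flink.api.table.typeutils.{TypeCheckUtils, TypeCoercion}

    object DecimalTypeUtilsSketch {
      def main(args: Array[String]): Unit = {
        // BigDecimal counts as numeric even though it is not a NumericTypeInfo
        println(TypeCheckUtils.isNumeric(BIG_DEC_TYPE_INFO)) // true
        println(TypeCheckUtils.isDecimal(BIG_DEC_TYPE_INFO)) // true

        // the TypeCoercion hunks below make DECIMAL the widest numeric type ...
        println(TypeCoercion.widerTypeOf(INT_TYPE_INFO, BIG_DEC_TYPE_INFO)) // Some(BigDecimal)

        // ... treat numeric -> DECIMAL as a lossless, safe cast ...
        println(TypeCoercion.canSafelyCast(LONG_TYPE_INFO, BIG_DEC_TYPE_INFO)) // true

        // ... and permit the possibly lossy DECIMAL -> numeric cast only explicitly
        println(TypeCoercion.canCast(BIG_DEC_TYPE_INFO, LONG_TYPE_INFO)) // true
      }
    }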
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCoercion.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCoercion.scala
index 218996da9a3..baa12d96646 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCoercion.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCoercion.scala
@@ -37,11 +37,14 @@ object TypeCoercion {
   def widerTypeOf(tp1: TypeInformation[_], tp2: TypeInformation[_]): Option[TypeInformation[_]] = {
     (tp1, tp2) match {
-      case (tp1, tp2) if tp1 == tp2 => Some(tp1)
+      case (ti1, ti2) if ti1 == ti2 => Some(ti1)
 
       case (_, STRING_TYPE_INFO) => Some(STRING_TYPE_INFO)
       case (STRING_TYPE_INFO, _) => Some(STRING_TYPE_INFO)
 
+      case (_, BIG_DEC_TYPE_INFO) => Some(BIG_DEC_TYPE_INFO)
+      case (BIG_DEC_TYPE_INFO, _) => Some(BIG_DEC_TYPE_INFO)
+
       case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) =>
         val higherIndex = numericWideningPrecedence.lastIndexWhere(t => t == tp1 || t == tp2)
         Some(numericWideningPrecedence(higherIndex))
@@ -56,6 +59,8 @@ object TypeCoercion {
   def canSafelyCast(from: TypeInformation[_], to: TypeInformation[_]): Boolean = (from, to) match {
     case (_, STRING_TYPE_INFO) => true
 
+    case (_: NumericTypeInfo[_], BIG_DEC_TYPE_INFO) => true
+
     case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) =>
       if (numericWideningPrecedence.indexOf(from) < numericWideningPrecedence.indexOf(to)) {
         true
@@ -71,21 +76,24 @@
    * Note: This may lose information during the cast.
    */
   def canCast(from: TypeInformation[_], to: TypeInformation[_]): Boolean = (from, to) match {
-    case (from, to) if from == to => true
+    case (fromTp, toTp) if fromTp == toTp => true
 
     case (_, STRING_TYPE_INFO) => true
 
-    case (_, DATE_TYPE_INFO) => false // Date type not supported yet.
-    case (_, VOID_TYPE_INFO) => false // Void type not supported
     case (_, CHAR_TYPE_INFO) => false // Character type not supported.
 
     case (STRING_TYPE_INFO, _: NumericTypeInfo[_]) => true
     case (STRING_TYPE_INFO, BOOLEAN_TYPE_INFO) => true
+    case (STRING_TYPE_INFO, BIG_DEC_TYPE_INFO) => true
 
     case (BOOLEAN_TYPE_INFO, _: NumericTypeInfo[_]) => true
+    case (BOOLEAN_TYPE_INFO, BIG_DEC_TYPE_INFO) => true
     case (_: NumericTypeInfo[_], BOOLEAN_TYPE_INFO) => true
+    case (BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO) => true
 
     case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => true
+    case (BIG_DEC_TYPE_INFO, _: NumericTypeInfo[_]) => true
+    case (_: NumericTypeInfo[_], BIG_DEC_TYPE_INFO) => true
 
     case _ => false
   }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala
index cf23434c052..a5b248487ec 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala
@@ -56,6 +56,7 @@ object TypeConverter {
     case STRING_TYPE_INFO => VARCHAR
     case STRING_VALUE_TYPE_INFO => VARCHAR
     case DATE_TYPE_INFO => DATE
+    case BIG_DEC_TYPE_INFO => DECIMAL
 
     case CHAR_TYPE_INFO | CHAR_VALUE_TYPE_INFO =>
       throw new TableException("Character type is not supported.")
@@ -74,6 +75,7 @@
     case DOUBLE => DOUBLE_TYPE_INFO
     case VARCHAR | CHAR => STRING_TYPE_INFO
     case DATE => DATE_TYPE_INFO
+    case DECIMAL => BIG_DEC_TYPE_INFO
 
     case NULL =>
       throw new TableException("Type NULL is not supported. " +
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
index 01ec94aa93a..2dce751b3dd 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
@@ -98,7 +98,9 @@
     val env = ExecutionEnvironment.getExecutionEnvironment
     val tEnv = TableEnvironment.getTableEnvironment(env, config)
 
-    val sqlQuery = "SELECT avg(_1), avg(_2), avg(_3), avg(_4), avg(_5), avg(_6), count(_7)" +
+    val sqlQuery =
+      "SELECT avg(_1), avg(_2), avg(_3), avg(_4), avg(_5), avg(_6), count(_7), " +
+      " sum(CAST(_6 AS DECIMAL)) " +
       "FROM MyTable"
 
     val ds = env.fromElements(
@@ -108,7 +110,7 @@
     val result = tEnv.sql(sqlQuery)
 
-    val expected = "1,1,1,1,1.5,1.5,2"
+    val expected = "1,1,1,1,1.5,1.5,2,3.0"
     val results = result.toDataSet[Row].collect()
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/ExpressionsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/ExpressionsITCase.scala
index 9a0a035d01a..0b51175c84e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/ExpressionsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/ExpressionsITCase.scala
@@ -157,6 +157,23 @@
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
+  @Test
+  def testDecimalLiteral(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val t = env
+      .fromElements(
+        (BigDecimal("78.454654654654654").bigDecimal, BigDecimal("4E+9999").bigDecimal)
+      )
+      .toTable(tEnv, 'a, 'b)
+      .select('a, 'b, BigDecimal("11.2"), BigDecimal("11.2").bigDecimal)
+
+    val expected = "78.454654654654654,4E+9999,11.2,11.2"
+    val results = t.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
   // Date literals not yet supported
   @Ignore
   @Test
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/DecimalTypeTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/DecimalTypeTest.scala
new file mode 100644
index 00000000000..082bdec060b
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/DecimalTypeTest.scala
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.expressions + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.table.Row +import org.apache.flink.api.table.expressions.utils.ExpressionTestBase +import org.apache.flink.api.table.typeutils.RowTypeInfo +import org.junit.Test + +class DecimalTypeTest extends ExpressionTestBase { + + @Test + def testDecimalLiterals(): Unit = { + // implicit double + testAllApis( + 11.2, + "11.2", + "11.2", + "11.2") + + // implicit double + testAllApis( + 0.7623533651719233, + "0.7623533651719233", + "0.7623533651719233", + "0.7623533651719233") + + // explicit decimal (with precision of 19) + testAllApis( + BigDecimal("1234567891234567891"), + "1234567891234567891p", + "1234567891234567891", + "1234567891234567891") + + // explicit decimal (high precision, not SQL compliant) + testTableApi( + BigDecimal("123456789123456789123456789"), + "123456789123456789123456789p", + "123456789123456789123456789") + + // explicit decimal (high precision, not SQL compliant) + testTableApi( + BigDecimal("12.3456789123456789123456789"), + "12.3456789123456789123456789p", + "12.3456789123456789123456789") + } + + @Test + def testDecimalBorders(): Unit = { + testAllApis( + Double.MaxValue, + Double.MaxValue.toString, + Double.MaxValue.toString, + Double.MaxValue.toString) + + testAllApis( + Double.MinValue, + Double.MinValue.toString, + Double.MinValue.toString, + Double.MinValue.toString) + + testAllApis( + Double.MinValue.cast(FLOAT_TYPE_INFO), + s"${Double.MinValue}.cast(FLOAT)", + s"CAST(${Double.MinValue} AS FLOAT)", + Float.NegativeInfinity.toString) + + testAllApis( + Byte.MinValue.cast(BYTE_TYPE_INFO), + s"(${Byte.MinValue}).cast(BYTE)", + s"CAST(${Byte.MinValue} AS TINYINT)", + Byte.MinValue.toString) + + testAllApis( + Byte.MinValue.cast(BYTE_TYPE_INFO) - 1.cast(BYTE_TYPE_INFO), + s"(${Byte.MinValue}).cast(BYTE) - (1).cast(BYTE)", + s"CAST(${Byte.MinValue} AS TINYINT) - CAST(1 AS TINYINT)", + Byte.MaxValue.toString) + + testAllApis( + Short.MinValue.cast(SHORT_TYPE_INFO), + s"(${Short.MinValue}).cast(SHORT)", + s"CAST(${Short.MinValue} AS SMALLINT)", + Short.MinValue.toString) + + testAllApis( + Int.MinValue.cast(INT_TYPE_INFO) - 1, + s"(${Int.MinValue}).cast(INT) - 1", + s"CAST(${Int.MinValue} AS INT) - 1", + Int.MaxValue.toString) + + testAllApis( + Long.MinValue.cast(LONG_TYPE_INFO), + s"(${Long.MinValue}L).cast(LONG)", + s"CAST(${Long.MinValue} AS BIGINT)", + Long.MinValue.toString) + } + + @Test + def testDecimalCasting(): Unit = { + // from String + testTableApi( + "123456789123456789123456789".cast(BIG_DEC_TYPE_INFO), + "'123456789123456789123456789'.cast(DECIMAL)", + "123456789123456789123456789") + + // from double + testAllApis( + 'f3.cast(BIG_DEC_TYPE_INFO), + "f3.cast(DECIMAL)", + "CAST(f3 AS DECIMAL)", + "4.2") + + // to double + testAllApis( + 'f0.cast(DOUBLE_TYPE_INFO), + "f0.cast(DOUBLE)", + "CAST(f0 AS DOUBLE)", + "1.2345678912345679E8") + + // to int + testAllApis( + 'f4.cast(INT_TYPE_INFO), + "f4.cast(INT)", + "CAST(f4 AS INT)", + "123456789") + + // to long + testAllApis( + 'f4.cast(LONG_TYPE_INFO), + "f4.cast(LONG)", + "CAST(f4 AS BIGINT)", + "123456789") + + // to boolean (not SQL compliant) + testTableApi( + 'f1.cast(BOOLEAN_TYPE_INFO), + "f1.cast(BOOL)", + "true") + + testTableApi( + 
'f5.cast(BOOLEAN_TYPE_INFO), + "f5.cast(BOOL)", + "false") + + testTableApi( + BigDecimal("123456789.123456789123456789").cast(DOUBLE_TYPE_INFO), + "(123456789.123456789123456789p).cast(DOUBLE)", + "1.2345678912345679E8") + } + + @Test + def testDecimalArithmetic(): Unit = { + // implicit cast to decimal + testAllApis( + 'f1 + 12, + "f1 + 12", + "f1 + 12", + "123456789123456789123456801") + + // implicit cast to decimal + testAllApis( + Literal(12) + 'f1, + "12 + f1", + "12 + f1", + "123456789123456789123456801") + + // implicit cast to decimal + testAllApis( + 'f1 + 12.3, + "f1 + 12.3", + "f1 + 12.3", + "123456789123456789123456801.3") + + // implicit cast to decimal + testAllApis( + Literal(12.3) + 'f1, + "12.3 + f1", + "12.3 + f1", + "123456789123456789123456801.3") + + testAllApis( + 'f1 + 'f1, + "f1 + f1", + "f1 + f1", + "246913578246913578246913578") + + testAllApis( + 'f1 - 'f1, + "f1 - f1", + "f1 - f1", + "0") + + testAllApis( + 'f1 * 'f1, + "f1 * f1", + "f1 * f1", + "15241578780673678546105778281054720515622620750190521") + + testAllApis( + 'f1 / 'f1, + "f1 / f1", + "f1 / f1", + "1") + + testAllApis( + 'f1 % 'f1, + "f1 % f1", + "MOD(f1, f1)", + "0") + + testAllApis( + -'f0, + "-f0", + "-f0", + "-123456789.123456789123456789") + } + + @Test + def testDecimalComparison(): Unit = { + testAllApis( + 'f1 < 12, + "f1 < 12", + "f1 < 12", + "false") + + testAllApis( + 'f1 > 12, + "f1 > 12", + "f1 > 12", + "true") + + testAllApis( + 'f1 === 12, + "f1 === 12", + "f1 = 12", + "false") + + testAllApis( + 'f5 === 0, + "f5 === 0", + "f5 = 0", + "true") + + testAllApis( + 'f1 === BigDecimal("123456789123456789123456789"), + "f1 === 123456789123456789123456789p", + "f1 = CAST('123456789123456789123456789' AS DECIMAL)", + "true") + + testAllApis( + 'f1 !== BigDecimal("123456789123456789123456789"), + "f1 !== 123456789123456789123456789p", + "f1 <> CAST('123456789123456789123456789' AS DECIMAL)", + "false") + + testAllApis( + 'f4 < 'f0, + "f4 < f0", + "f4 < f0", + "true") + + // TODO add all tests if FLINK-4070 is fixed + testSqlApi( + "12 < f1", + "true") + } + + // ---------------------------------------------------------------------------------------------- + + def testData = { + val testData = new Row(6) + testData.setField(0, BigDecimal("123456789.123456789123456789").bigDecimal) + testData.setField(1, BigDecimal("123456789123456789123456789").bigDecimal) + testData.setField(2, 42) + testData.setField(3, 4.2) + testData.setField(4, BigDecimal("123456789").bigDecimal) + testData.setField(5, BigDecimal("0.000").bigDecimal) + testData + } + + def typeInfo = { + new RowTypeInfo(Seq( + BIG_DEC_TYPE_INFO, + BIG_DEC_TYPE_INFO, + INT_TYPE_INFO, + DOUBLE_TYPE_INFO, + BIG_DEC_TYPE_INFO, + BIG_DEC_TYPE_INFO)).asInstanceOf[TypeInformation[Any]] + } + +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ScalarFunctionsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ScalarFunctionsTest.scala index a3b67f511ec..640d6047877 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ScalarFunctionsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ScalarFunctionsTest.scala @@ -20,15 +20,13 @@ package org.apache.flink.api.table.expressions import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.table.expressions.utils.ExpressionEvaluator import 
org.apache.flink.api.scala.table._ import org.apache.flink.api.table.Row -import org.apache.flink.api.table.expressions.{Expression, ExpressionParser} +import org.apache.flink.api.table.expressions.utils.ExpressionTestBase import org.apache.flink.api.table.typeutils.RowTypeInfo -import org.junit.Assert.assertEquals import org.junit.Test -class ScalarFunctionsTest { +class ScalarFunctionsTest extends ExpressionTestBase { // ---------------------------------------------------------------------------------------------- // String functions @@ -36,19 +34,19 @@ class ScalarFunctionsTest { @Test def testSubstring(): Unit = { - testFunction( + testAllApis( 'f0.substring(2), "f0.substring(2)", "SUBSTRING(f0, 2)", "his is a test String.") - testFunction( + testAllApis( 'f0.substring(2, 5), "f0.substring(2, 5)", "SUBSTRING(f0, 2, 5)", "his i") - testFunction( + testAllApis( 'f0.substring(1, 'f7), "f0.substring(1, f7)", "SUBSTRING(f0, 1, f7)", @@ -57,25 +55,25 @@ class ScalarFunctionsTest { @Test def testTrim(): Unit = { - testFunction( + testAllApis( 'f8.trim(), "f8.trim()", "TRIM(f8)", "This is a test String.") - testFunction( + testAllApis( 'f8.trim(removeLeading = true, removeTrailing = true, " "), "trim(f8)", "TRIM(f8)", "This is a test String.") - testFunction( + testAllApis( 'f8.trim(removeLeading = false, removeTrailing = true, " "), "f8.trim(TRAILING, ' ')", "TRIM(TRAILING FROM f8)", " This is a test String.") - testFunction( + testAllApis( 'f0.trim(removeLeading = true, removeTrailing = true, "."), "trim(BOTH, '.', f0)", "TRIM(BOTH '.' FROM f0)", @@ -84,13 +82,13 @@ class ScalarFunctionsTest { @Test def testCharLength(): Unit = { - testFunction( + testAllApis( 'f0.charLength(), "f0.charLength()", "CHAR_LENGTH(f0)", "22") - testFunction( + testAllApis( 'f0.charLength(), "charLength(f0)", "CHARACTER_LENGTH(f0)", @@ -99,7 +97,7 @@ class ScalarFunctionsTest { @Test def testUpperCase(): Unit = { - testFunction( + testAllApis( 'f0.upperCase(), "f0.upperCase()", "UPPER(f0)", @@ -108,7 +106,7 @@ class ScalarFunctionsTest { @Test def testLowerCase(): Unit = { - testFunction( + testAllApis( 'f0.lowerCase(), "f0.lowerCase()", "LOWER(f0)", @@ -117,7 +115,7 @@ class ScalarFunctionsTest { @Test def testInitCap(): Unit = { - testFunction( + testAllApis( 'f0.initCap(), "f0.initCap()", "INITCAP(f0)", @@ -126,7 +124,7 @@ class ScalarFunctionsTest { @Test def testConcat(): Unit = { - testFunction( + testAllApis( 'f0 + 'f0, "f0 + f0", "f0||f0", @@ -135,13 +133,13 @@ class ScalarFunctionsTest { @Test def testLike(): Unit = { - testFunction( + testAllApis( 'f0.like("Th_s%"), "f0.like('Th_s%')", "f0 LIKE 'Th_s%'", "true") - testFunction( + testAllApis( 'f0.like("%is a%"), "f0.like('%is a%')", "f0 LIKE '%is a%'", @@ -150,13 +148,13 @@ class ScalarFunctionsTest { @Test def testNotLike(): Unit = { - testFunction( + testAllApis( !'f0.like("Th_s%"), "!f0.like('Th_s%')", "f0 NOT LIKE 'Th_s%'", "false") - testFunction( + testAllApis( !'f0.like("%is a%"), "!f0.like('%is a%')", "f0 NOT LIKE '%is a%'", @@ -165,13 +163,13 @@ class ScalarFunctionsTest { @Test def testSimilar(): Unit = { - testFunction( + testAllApis( 'f0.similar("_*"), "f0.similar('_*')", "f0 SIMILAR TO '_*'", "true") - testFunction( + testAllApis( 'f0.similar("This (is)? a (test)+ Strin_*"), "f0.similar('This (is)? a (test)+ Strin_*')", "f0 SIMILAR TO 'This (is)? 
a (test)+ Strin_*'", @@ -180,13 +178,13 @@ class ScalarFunctionsTest { @Test def testNotSimilar(): Unit = { - testFunction( + testAllApis( !'f0.similar("_*"), "!f0.similar('_*')", "f0 NOT SIMILAR TO '_*'", "false") - testFunction( + testAllApis( !'f0.similar("This (is)? a (test)+ Strin_*"), "!f0.similar('This (is)? a (test)+ Strin_*')", "f0 NOT SIMILAR TO 'This (is)? a (test)+ Strin_*'", @@ -195,19 +193,19 @@ class ScalarFunctionsTest { @Test def testMod(): Unit = { - testFunction( + testAllApis( 'f4.mod('f7), "f4.mod(f7)", "MOD(f4, f7)", "2") - testFunction( + testAllApis( 'f4.mod(3), "mod(f4, 3)", "MOD(f4, 3)", "2") - testFunction( + testAllApis( 'f4 % 3, "mod(44, 3)", "MOD(44, 3)", @@ -217,37 +215,43 @@ class ScalarFunctionsTest { @Test def testExp(): Unit = { - testFunction( + testAllApis( 'f2.exp(), "f2.exp()", "EXP(f2)", math.exp(42.toByte).toString) - testFunction( + testAllApis( 'f3.exp(), "f3.exp()", "EXP(f3)", math.exp(43.toShort).toString) - testFunction( + testAllApis( 'f4.exp(), "f4.exp()", "EXP(f4)", math.exp(44.toLong).toString) - testFunction( + testAllApis( 'f5.exp(), "f5.exp()", "EXP(f5)", math.exp(4.5.toFloat).toString) - testFunction( + testAllApis( 'f6.exp(), "f6.exp()", "EXP(f6)", math.exp(4.6).toString) - testFunction( + testAllApis( + 'f7.exp(), + "exp(3)", + "EXP(3)", + math.exp(3).toString) + + testAllApis( 'f7.exp(), "exp(3)", "EXP(3)", @@ -256,31 +260,31 @@ class ScalarFunctionsTest { @Test def testLog10(): Unit = { - testFunction( + testAllApis( 'f2.log10(), "f2.log10()", "LOG10(f2)", math.log10(42.toByte).toString) - testFunction( + testAllApis( 'f3.log10(), "f3.log10()", "LOG10(f3)", math.log10(43.toShort).toString) - testFunction( + testAllApis( 'f4.log10(), "f4.log10()", "LOG10(f4)", math.log10(44.toLong).toString) - testFunction( + testAllApis( 'f5.log10(), "f5.log10()", "LOG10(f5)", math.log10(4.5.toFloat).toString) - testFunction( + testAllApis( 'f6.log10(), "f6.log10()", "LOG10(f6)", @@ -289,19 +293,25 @@ class ScalarFunctionsTest { @Test def testPower(): Unit = { - testFunction( + testAllApis( 'f2.power('f7), "f2.power(f7)", "POWER(f2, f7)", math.pow(42.toByte, 3).toString) - testFunction( + testAllApis( 'f3.power('f6), "f3.power(f6)", "POWER(f3, f6)", math.pow(43.toShort, 4.6D).toString) - testFunction( + testAllApis( + 'f4.power('f5), + "f4.power(f5)", + "POWER(f4, f5)", + math.pow(44.toLong, 4.5.toFloat).toString) + + testAllApis( 'f4.power('f5), "f4.power(f5)", "POWER(f4, f5)", @@ -310,31 +320,31 @@ class ScalarFunctionsTest { @Test def testLn(): Unit = { - testFunction( + testAllApis( 'f2.ln(), "f2.ln()", "LN(f2)", math.log(42.toByte).toString) - testFunction( + testAllApis( 'f3.ln(), "f3.ln()", "LN(f3)", math.log(43.toShort).toString) - testFunction( + testAllApis( 'f4.ln(), "f4.ln()", "LN(f4)", math.log(44.toLong).toString) - testFunction( + testAllApis( 'f5.ln(), "f5.ln()", "LN(f5)", math.log(4.5.toFloat).toString) - testFunction( + testAllApis( 'f6.ln(), "f6.ln()", "LN(f6)", @@ -343,102 +353,116 @@ class ScalarFunctionsTest { @Test def testAbs(): Unit = { - testFunction( + testAllApis( 'f2.abs(), "f2.abs()", "ABS(f2)", "42") - testFunction( + testAllApis( 'f3.abs(), "f3.abs()", "ABS(f3)", "43") - testFunction( + testAllApis( 'f4.abs(), "f4.abs()", "ABS(f4)", "44") - testFunction( + testAllApis( 'f5.abs(), "f5.abs()", "ABS(f5)", "4.5") - testFunction( + testAllApis( 'f6.abs(), "f6.abs()", "ABS(f6)", "4.6") - testFunction( + testAllApis( 'f9.abs(), "f9.abs()", "ABS(f9)", "42") - testFunction( + testAllApis( 'f10.abs(), "f10.abs()", "ABS(f10)", 
"43") - testFunction( + testAllApis( 'f11.abs(), "f11.abs()", "ABS(f11)", "44") - testFunction( + testAllApis( 'f12.abs(), "f12.abs()", "ABS(f12)", "4.5") - testFunction( + testAllApis( 'f13.abs(), "f13.abs()", "ABS(f13)", "4.6") + + testAllApis( + 'f15.abs(), + "f15.abs()", + "ABS(f15)", + "1231.1231231321321321111") } @Test def testArithmeticFloorCeil(): Unit = { - testFunction( + testAllApis( 'f5.floor(), "f5.floor()", "FLOOR(f5)", "4.0") - testFunction( + testAllApis( 'f5.ceil(), "f5.ceil()", "CEIL(f5)", "5.0") - testFunction( + testAllApis( 'f3.floor(), "f3.floor()", "FLOOR(f3)", "43") - testFunction( + testAllApis( 'f3.ceil(), "f3.ceil()", "CEIL(f3)", "43") + + testAllApis( + 'f15.floor(), + "f15.floor()", + "FLOOR(f15)", + "-1232") + + testAllApis( + 'f15.ceil(), + "f15.ceil()", + "CEIL(f15)", + "-1231") } // ---------------------------------------------------------------------------------------------- - def testFunction( - expr: Expression, - exprString: String, - sqlExpr: String, - expected: String): Unit = { - val testData = new Row(15) + def testData = { + val testData = new Row(16) testData.setField(0, "This is a test String.") testData.setField(1, true) testData.setField(2, 42.toByte) @@ -454,8 +478,12 @@ class ScalarFunctionsTest { testData.setField(12, -4.5.toFloat) testData.setField(13, -4.6) testData.setField(14, -3) + testData.setField(15, BigDecimal("-1231.1231231321321321111").bigDecimal) + testData + } - val typeInfo = new RowTypeInfo(Seq( + def typeInfo = { + new RowTypeInfo(Seq( STRING_TYPE_INFO, BOOLEAN_TYPE_INFO, BYTE_TYPE_INFO, @@ -470,21 +498,7 @@ class ScalarFunctionsTest { LONG_TYPE_INFO, FLOAT_TYPE_INFO, DOUBLE_TYPE_INFO, - INT_TYPE_INFO)).asInstanceOf[TypeInformation[Any]] - - val exprResult = ExpressionEvaluator.evaluate(testData, typeInfo, expr) - assertEquals(expected, exprResult) - - val exprStringResult = ExpressionEvaluator.evaluate( - testData, - typeInfo, - ExpressionParser.parseExpression(exprString)) - assertEquals(expected, exprStringResult) - - val exprSqlResult = ExpressionEvaluator.evaluate(testData, typeInfo, sqlExpr) - assertEquals(expected, exprSqlResult) + INT_TYPE_INFO, + BIG_DEC_TYPE_INFO)).asInstanceOf[TypeInformation[Any]] } - - - } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionEvaluator.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionEvaluator.scala deleted file mode 100644 index 0b5a2dedd42..00000000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionEvaluator.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.api.table.expressions.utils - -import org.apache.calcite.rel.logical.LogicalProject -import org.apache.calcite.rex.RexNode -import org.apache.calcite.sql.`type`.SqlTypeName.VARCHAR -import org.apache.calcite.tools.{Frameworks, RelBuilder} -import org.apache.flink.api.common.functions.{Function, MapFunction} -import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.{DataSet => JDataSet} -import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment} -import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedFunction} -import org.apache.flink.api.table.expressions.Expression -import org.apache.flink.api.table.runtime.FunctionCompiler -import org.apache.flink.api.table.{BatchTableEnvironment, TableConfig, TableEnvironment} -import org.mockito.Mockito._ - -/** - * Utility to translate and evaluate an RexNode or Table API expression to a String. - */ -object ExpressionEvaluator { - - // TestCompiler that uses current class loader - class TestCompiler[T <: Function] extends FunctionCompiler[T] { - def compile(genFunc: GeneratedFunction[T]): Class[T] = - compile(getClass.getClassLoader, genFunc.name, genFunc.code) - } - - private def prepareTable( - typeInfo: TypeInformation[Any]): (String, RelBuilder, TableEnvironment) = { - - // create DataSetTable - val dataSetMock = mock(classOf[DataSet[Any]]) - val jDataSetMock = mock(classOf[JDataSet[Any]]) - when(dataSetMock.javaSet).thenReturn(jDataSetMock) - when(jDataSetMock.getType).thenReturn(typeInfo) - - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - - val tableName = "myTable" - tEnv.registerDataSet(tableName, dataSetMock) - - // prepare RelBuilder - val relBuilder = tEnv.getRelBuilder - relBuilder.scan(tableName) - - (tableName, relBuilder, tEnv) - } - - def evaluate(data: Any, typeInfo: TypeInformation[Any], sqlExpr: String): String = { - // create DataSetTable - val table = prepareTable(typeInfo) - - // create RelNode from SQL expression - val planner = Frameworks.getPlanner(table._3.getFrameworkConfig) - val parsed = planner.parse("SELECT " + sqlExpr + " FROM " + table._1) - val validated = planner.validate(parsed) - val converted = planner.rel(validated) - - val expr: RexNode = converted.rel.asInstanceOf[LogicalProject].getChildExps.get(0) - - evaluate(data, typeInfo, table._2, expr) - } - - def evaluate(data: Any, typeInfo: TypeInformation[Any], expr: Expression): String = { - val table = prepareTable(typeInfo) - val env = table._3 - val resolvedExpr = - env.asInstanceOf[BatchTableEnvironment].scan("myTable").select(expr). 
- getRelNode.asInstanceOf[LogicalProject].getChildExps.get(0) - evaluate(data, typeInfo, table._2, resolvedExpr) - } - - def evaluate( - data: Any, - typeInfo: TypeInformation[Any], - relBuilder: RelBuilder, - rexNode: RexNode): String = { - // generate code for Mapper - val config = new TableConfig() - val generator = new CodeGenerator(config, false, typeInfo) - val genExpr = generator.generateExpression(relBuilder.cast(rexNode, VARCHAR)) // cast to String - val bodyCode = - s""" - |${genExpr.code} - |return ${genExpr.resultTerm}; - |""".stripMargin - val genFunc = generator.generateFunction[MapFunction[Any, String]]( - "TestFunction", - classOf[MapFunction[Any, String]], - bodyCode, - STRING_TYPE_INFO.asInstanceOf[TypeInformation[Any]]) - - // compile and evaluate - val clazz = new TestCompiler[MapFunction[Any, String]]().compile(genFunc) - val mapper = clazz.newInstance() - mapper.map(data) - } - -} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala new file mode 100644 index 00000000000..4345dd82fbd --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.expressions.utils + +import org.apache.calcite.rel.logical.LogicalProject +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.`type`.SqlTypeName._ +import org.apache.calcite.tools.{Frameworks, RelBuilder} +import org.apache.flink.api.common.functions.{Function, MapFunction} +import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.{DataSet => JDataSet} +import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment} +import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedFunction} +import org.apache.flink.api.table.expressions.{Expression, ExpressionParser} +import org.apache.flink.api.table.runtime.FunctionCompiler +import org.apache.flink.api.table.typeutils.RowTypeInfo +import org.apache.flink.api.table.{BatchTableEnvironment, Row, TableConfig, TableEnvironment} +import org.junit.Assert._ +import org.junit.{After, Before} +import org.mockito.Mockito._ + +import scala.collection.mutable + +/** + * Base test class for expression tests. 
+ */ +abstract class ExpressionTestBase { + + private val testExprs = mutable.LinkedHashSet[(RexNode, String)]() + + // setup test utils + private val tableName = "testTable" + private val context = prepareContext(typeInfo) + private val planner = Frameworks.getPlanner(context._2.getFrameworkConfig) + + private def prepareContext(typeInfo: TypeInformation[Any]): (RelBuilder, TableEnvironment) = { + // create DataSetTable + val dataSetMock = mock(classOf[DataSet[Any]]) + val jDataSetMock = mock(classOf[JDataSet[Any]]) + when(dataSetMock.javaSet).thenReturn(jDataSetMock) + when(jDataSetMock.getType).thenReturn(typeInfo) + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + tEnv.registerDataSet(tableName, dataSetMock) + + // prepare RelBuilder + val relBuilder = tEnv.getRelBuilder + relBuilder.scan(tableName) + + (relBuilder, tEnv) + } + + def testData: Any + + def typeInfo: TypeInformation[Any] + + @Before + def resetTestExprs() = { + testExprs.clear() + } + + @After + def evaluateExprs() = { + val relBuilder = context._1 + val config = new TableConfig() + val generator = new CodeGenerator(config, false, typeInfo) + + // cast expressions to String + val stringTestExprs = testExprs.map(expr => relBuilder.cast(expr._1, VARCHAR)).toSeq + + // generate code + val resultType = new RowTypeInfo(Seq.fill(testExprs.size)(STRING_TYPE_INFO)) + val genExpr = generator.generateResultExpression( + resultType, + resultType.getFieldNames, + stringTestExprs) + + val bodyCode = + s""" + |${genExpr.code} + |return ${genExpr.resultTerm}; + |""".stripMargin + + val genFunc = generator.generateFunction[MapFunction[Any, String]]( + "TestFunction", + classOf[MapFunction[Any, String]], + bodyCode, + resultType.asInstanceOf[TypeInformation[Any]]) + + // compile and evaluate + val clazz = new TestCompiler[MapFunction[Any, String]]().compile(genFunc) + val mapper = clazz.newInstance() + val result = mapper.map(testData).asInstanceOf[Row] + + // compare + testExprs + .zipWithIndex + .foreach { + case ((expr, expected), index) => + assertEquals(s"Wrong result for: $expr", expected, result.productElement(index)) + } + } + + private def addSqlTestExpr(sqlExpr: String, expected: String): Unit = { + // create RelNode from SQL expression + val parsed = planner.parse(s"SELECT $sqlExpr FROM $tableName") + val validated = planner.validate(parsed) + val converted = planner.rel(validated) + + // extract RexNode + val expr: RexNode = converted.rel.asInstanceOf[LogicalProject].getChildExps.get(0) + testExprs.add((expr, expected)) + + planner.close() + } + + private def addTableApiTestExpr(tableApiExpr: Expression, expected: String): Unit = { + val env = context._2 + val expr = env + .asInstanceOf[BatchTableEnvironment] + .scan(tableName) + .select(tableApiExpr) + .getRelNode + .asInstanceOf[LogicalProject] + .getChildExps + .get(0) + testExprs.add((expr, expected)) + } + + private def addTableApiTestExpr(tableApiString: String, expected: String): Unit = { + addTableApiTestExpr(ExpressionParser.parseExpression(tableApiString), expected) + } + + def testAllApis( + expr: Expression, + exprString: String, + sqlExpr: String, + expected: String) + : Unit = { + addTableApiTestExpr(expr, expected) + addTableApiTestExpr(exprString, expected) + addSqlTestExpr(sqlExpr, expected) + } + + def testTableApi( + expr: Expression, + exprString: String, + expected: String) + : Unit = { + addTableApiTestExpr(expr, expected) + addTableApiTestExpr(exprString, expected) + } + + def testSqlApi( 
+      sqlExpr: String,
+      expected: String)
+    : Unit = {
+    addSqlTestExpr(sqlExpr, expected)
+  }
+
+  // ----------------------------------------------------------------------------------------------
+
+  // TestCompiler that uses current class loader
+  class TestCompiler[T <: Function] extends FunctionCompiler[T] {
+    def compile(genFunc: GeneratedFunction[T]): Class[T] =
+      compile(getClass.getClassLoader, genFunc.name, genFunc.code)
+  }
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala
index 78d5f8c0bac..54911a55742 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala
@@ -18,6 +18,7 @@
 package org.apache.flink.api.table.runtime.aggregate
 
+import java.math.BigDecimal
 import org.apache.flink.api.table.Row
 import org.junit.Test
 import org.junit.Assert.assertEquals
@@ -63,8 +64,13 @@ abstract class AggregateTestBase[T] {
       finalAgg(rows)
     }
 
-    assertEquals(expected, result)
-
+    (expected, result) match {
+      case (e: BigDecimal, r: BigDecimal) =>
+        // BigDecimal.equals() compares both value and scale, but we are only interested in the value.
+        assert(e.compareTo(r) == 0)
+      case _ =>
+        assertEquals(expected, result)
+    }
   }
 }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala
index 48dc313f6dd..23b305481c6 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala
@@ -18,6 +18,8 @@
 package org.apache.flink.api.table.runtime.aggregate
 
+import java.math.BigDecimal
+
 abstract class AvgAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
 
   private val numeric: Numeric[T] = implicitly[Numeric[T]]
@@ -122,3 +124,31 @@ class DoubleAvgAggregateTest extends AvgAggregateTestBase[Double] {
 
   override def aggregator = new DoubleAvgAggregate()
 }
+
+class DecimalAvgAggregateTest extends AggregateTestBase[BigDecimal] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(
+      new BigDecimal("987654321000000"),
+      new BigDecimal("-0.000000000012345"),
+      null,
+      new BigDecimal("0.000000000012345"),
+      new BigDecimal("-987654321000000"),
+      null,
+      new BigDecimal("0")
+    ),
+    Seq(
+      null,
+      null,
+      null,
+      null
+    )
+  )
+
+  override def expectedResults: Seq[BigDecimal] = Seq(
+    BigDecimal.ZERO,
+    null
+  )
+
+  override def aggregator: Aggregate[BigDecimal] = new DecimalAvgAggregate()
+}
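The compareTo()-based check added to AggregateTestBase above matters because java.math.BigDecimal#equals treats two values as equal only when both value and scale match, while the decimal aggregates may legally produce a different scale than the expected literal (e.g. BigDecimal.ZERO versus a sum that nets out to 0E-12, as in the average test above). A self-contained sketch of the distinction:

    import java.math.BigDecimal

    object DecimalEqualitySketch {
      def main(args: Array[String]): Unit = {
        val zero = BigDecimal.ZERO             // value 0, scale 0
        val alsoZero = new BigDecimal("0.000") // value 0, scale 3

        println(zero.equals(alsoZero))         // false: scales differ
        println(zero.compareTo(alsoZero) == 0) // true: numerically equal
      }
    }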
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregateTest.scala
index 97385ae095f..aea3318dbd6 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregateTest.scala
@@ -18,6 +18,8 @@
 package org.apache.flink.api.table.runtime.aggregate
 
+import java.math.BigDecimal
+
 abstract class MaxAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
 
   private val numeric: Numeric[T] = implicitly[Numeric[T]]
@@ -141,3 +143,35 @@ class BooleanMaxAggregateTest extends AggregateTestBase[Boolean] {
 
   override def aggregator: Aggregate[Boolean] = new BooleanMaxAggregate()
 }
+
+class DecimalMaxAggregateTest extends AggregateTestBase[BigDecimal] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(
+      new BigDecimal("1"),
+      new BigDecimal("1000.000001"),
+      new BigDecimal("-1"),
+      new BigDecimal("-999.998999"),
+      null,
+      new BigDecimal("0"),
+      new BigDecimal("-999.999"),
+      null,
+      new BigDecimal("999.999")
+    ),
+    Seq(
+      null,
+      null,
+      null,
+      null,
+      null
+    )
+  )
+
+  override def expectedResults: Seq[BigDecimal] = Seq(
+    new BigDecimal("1000.000001"),
+    null
+  )
+
+  override def aggregator: Aggregate[BigDecimal] = new DecimalMaxAggregate()
+
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregateTest.scala
index cd77c1076c2..f007d02cfdd 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregateTest.scala
@@ -18,6 +18,8 @@
 package org.apache.flink.api.table.runtime.aggregate
 
+import java.math.BigDecimal
+
 abstract class MinAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
 
   private val numeric: Numeric[T] = implicitly[Numeric[T]]
@@ -141,3 +143,35 @@ class BooleanMinAggregateTest extends AggregateTestBase[Boolean] {
 
   override def aggregator: Aggregate[Boolean] = new BooleanMinAggregate()
 }
+
+class DecimalMinAggregateTest extends AggregateTestBase[BigDecimal] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(
+      new BigDecimal("1"),
+      new BigDecimal("1000"),
+      new BigDecimal("-1"),
+      new BigDecimal("-999.998999"),
+      null,
+      new BigDecimal("0"),
+      new BigDecimal("-999.999"),
+      null,
+      new BigDecimal("999.999")
+    ),
+    Seq(
+      null,
+      null,
+      null,
+      null,
+      null
+    )
+  )
+
+  override def expectedResults: Seq[BigDecimal] = Seq(
+    new BigDecimal("-999.999"),
+    null
+  )
+
+  override def aggregator: Aggregate[BigDecimal] = new DecimalMinAggregate()
+
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregateTest.scala
index fb6fc39f4ff..7e4e47bee29 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregateTest.scala
@@ -18,6 +18,8 @@
 package org.apache.flink.api.table.runtime.aggregate
 
+import java.math.BigDecimal
+
 abstract class SumAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
 
   private val numeric: Numeric[T] = implicitly[Numeric[T]]
@@ -97,3 +99,39 @@ class DoubleSumAggregateTest extends SumAggregateTestBase[Double] {
 
   override def aggregator: Aggregate[Double] = new DoubleSumAggregate
 }
+
+class DecimalSumAggregateTest extends AggregateTestBase[BigDecimal] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(
+      new BigDecimal("1"),
+      new BigDecimal("2"),
+      new BigDecimal("3"),
+      null,
+      new BigDecimal("0"),
+      new BigDecimal("-1000"),
+      new BigDecimal("0.000000000002"),
+
new BigDecimal("1000"), + new BigDecimal("-0.000000000001"), + new BigDecimal("999.999"), + null, + new BigDecimal("4"), + new BigDecimal("-999.999"), + null + ), + Seq( + null, + null, + null, + null, + null + ) + ) + + override def expectedResults: Seq[BigDecimal] = Seq( + new BigDecimal("10.000000000001"), + null + ) + + override def aggregator: Aggregate[BigDecimal] = new DecimalSumAggregate() +} -- GitLab