Commit 37defbb4 authored by: twalthr

[FLINK-3859] [table] Add BigDecimal/BigInteger support to Table API

This closes #2088.
Parent b6e6b818
...@@ -752,7 +752,7 @@ suffixed = cast | as | aggregation | nullCheck | evaluate | functionCall ;
cast = composite , ".cast(" , dataType , ")" ;
-dataType = "BYTE" | "SHORT" | "INT" | "LONG" | "FLOAT" | "DOUBLE" | "BOOL" | "BOOLEAN" | "STRING" | "DATE" ;
+dataType = "BYTE" | "SHORT" | "INT" | "LONG" | "FLOAT" | "DOUBLE" | "BOOL" | "BOOLEAN" | "STRING" | "DATE" | "DECIMAL";
as = composite , ".as(" , fieldReference , ")" ;
...@@ -773,6 +773,8 @@ nullLiteral = "Null(" , dataType , ")" ;
Here, `literal` is a valid Java literal, `fieldReference` specifies a column in the data, and `functionIdentifier` specifies a supported scalar function. The
column names and function names follow Java identifier syntax. Expressions specified as Strings can also use prefix notation instead of suffix notation to call operators and functions.
+If working with exact numeric values or large decimals is required, the Table API also supports Java's BigDecimal type. In the Scala Table API decimals can be defined by `BigDecimal("123456")` and in Java by appending a "p" for precise, e.g. `123456p`.
{% top %}
......
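For illustration only (not part of the commit): a minimal sketch of how the new decimal support could be used from the Scala Table API. It assumes an existing Table `payments` with a DECIMAL column 'amount; how the table is registered or obtained is elided.

import org.apache.flink.api.scala.table._   // Scala expression DSL implicits
import org.apache.flink.api.table.Table

// `payments` is assumed to already exist and to contain a DECIMAL column 'amount.
def decimalExamples(payments: Table): Unit = {
  // Scala DSL: scala.math.BigDecimal literals are lifted by the new implicit conversions.
  val large = payments.filter('amount > BigDecimal("1000.00"))

  // String expressions: the "p" suffix marks a precise DECIMAL literal,
  // and "DECIMAL" is now a valid cast target.
  val adjusted = payments.select("amount + 0.05p, amount.cast(DECIMAL)")
}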
...@@ -220,6 +220,14 @@ trait ImplicitExpressionConversions { ...@@ -220,6 +220,14 @@ trait ImplicitExpressionConversions {
def expr = Literal(l) def expr = Literal(l)
} }
implicit class LiteralByteExpression(b: Byte) extends ImplicitExpressionOperations {
def expr = Literal(b)
}
implicit class LiteralShortExpression(s: Short) extends ImplicitExpressionOperations {
def expr = Literal(s)
}
implicit class LiteralIntExpression(i: Int) extends ImplicitExpressionOperations { implicit class LiteralIntExpression(i: Int) extends ImplicitExpressionOperations {
def expr = Literal(i) def expr = Literal(i)
} }
...@@ -240,11 +248,26 @@ trait ImplicitExpressionConversions { ...@@ -240,11 +248,26 @@ trait ImplicitExpressionConversions {
def expr = Literal(bool) def expr = Literal(bool)
} }
implicit class LiteralJavaDecimalExpression(javaDecimal: java.math.BigDecimal)
extends ImplicitExpressionOperations {
def expr = Literal(javaDecimal)
}
implicit class LiteralScalaDecimalExpression(scalaDecimal: scala.math.BigDecimal)
extends ImplicitExpressionOperations {
def expr = Literal(scalaDecimal.bigDecimal)
}
implicit def symbol2FieldExpression(sym: Symbol): Expression = UnresolvedFieldReference(sym.name) implicit def symbol2FieldExpression(sym: Symbol): Expression = UnresolvedFieldReference(sym.name)
implicit def byte2Literal(b: Byte): Expression = Literal(b)
implicit def short2Literal(s: Short): Expression = Literal(s)
implicit def int2Literal(i: Int): Expression = Literal(i) implicit def int2Literal(i: Int): Expression = Literal(i)
implicit def long2Literal(l: Long): Expression = Literal(l) implicit def long2Literal(l: Long): Expression = Literal(l)
implicit def double2Literal(d: Double): Expression = Literal(d) implicit def double2Literal(d: Double): Expression = Literal(d)
implicit def float2Literal(d: Float): Expression = Literal(d) implicit def float2Literal(d: Float): Expression = Literal(d)
implicit def string2Literal(str: String): Expression = Literal(str) implicit def string2Literal(str: String): Expression = Literal(str)
implicit def boolean2Literal(bool: Boolean): Expression = Literal(bool) implicit def boolean2Literal(bool: Boolean): Expression = Literal(bool)
implicit def javaDec2Literal(javaDec: java.math.BigDecimal): Expression = Literal(javaDec)
implicit def scalaDec2Literal(scalaDec: scala.math.BigDecimal): Expression =
Literal(scalaDec.bigDecimal)
} }
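For illustration only (not part of the commit): the new implicit classes and conversions above let both BigDecimal flavours appear directly as expressions; a small sketch, assuming the usual Scala Table API DSL import brings the trait's implicits into scope.

import org.apache.flink.api.scala.table._
import org.apache.flink.api.table.expressions.Expression

// scala.math.BigDecimal is lifted via LiteralScalaDecimalExpression / scalaDec2Literal ...
val scalaDecLiteral: Expression = BigDecimal("3.14").expr

// ... and java.math.BigDecimal via LiteralJavaDecimalExpression / javaDec2Literal.
val javaDecLiteral: Expression = new java.math.BigDecimal("3.14").expr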
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.table
import org.apache.calcite.rel.`type`.RelDataTypeSystemImpl
/**
* Custom type system for Flink.
*/
class FlinkTypeSystem extends RelDataTypeSystemImpl {
// we cannot use Int.MaxValue because of an overflow in Calcites type inference logic
// half should be enough for all use cases
override def getMaxNumericScale: Int = Int.MaxValue / 2
// we cannot use Int.MaxValue because of an overflow in Calcites type inference logic
// half should be enough for all use cases
override def getMaxNumericPrecision: Int = Int.MaxValue / 2
}
...@@ -69,6 +69,7 @@ abstract class TableEnvironment(val config: TableConfig) { ...@@ -69,6 +69,7 @@ abstract class TableEnvironment(val config: TableConfig) {
.defaultSchema(tables) .defaultSchema(tables)
.parserConfig(parserConfig) .parserConfig(parserConfig)
.costFactory(new DataSetCostFactory) .costFactory(new DataSetCostFactory)
.typeSystem(new FlinkTypeSystem)
.build .build
// the builder for Calcite RelNodes, Calcite's representation of a relational expression tree. // the builder for Calcite RelNodes, Calcite's representation of a relational expression tree.
......
...@@ -23,11 +23,11 @@ import java.util.concurrent.atomic.AtomicInteger
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo._
-import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.{FractionalTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
-import org.apache.flink.api.java.typeutils.{TypeExtractor, PojoTypeInfo, TupleTypeInfo}
+import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo, TypeExtractor}
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
-import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeCheckUtils}
object CodeGenUtils {
...@@ -97,11 +97,24 @@ object CodeGenUtils {
    case _ => "null"
  }
-  def requireNumeric(genExpr: GeneratedExpression) = genExpr.resultType match {
-    case nti: NumericTypeInfo[_] => // ok
-    case _ => throw new CodeGenException("Numeric expression type expected.")
+  def superPrimitive(typeInfo: TypeInformation[_]): String = typeInfo match {
+    case _: FractionalTypeInfo[_] => "double"
+    case _ => "long"
  }
// ----------------------------------------------------------------------------------------------
def requireNumeric(genExpr: GeneratedExpression) =
if (!TypeCheckUtils.isNumeric(genExpr.resultType)) {
throw new CodeGenException("Numeric expression type expected, but was " +
s"'${genExpr.resultType}'.")
}
def requireComparable(genExpr: GeneratedExpression) =
if (!TypeCheckUtils.isComparable(genExpr.resultType)) {
throw new CodeGenException(s"Comparable type expected, but was '${genExpr.resultType}'.")
}
def requireString(genExpr: GeneratedExpression) = genExpr.resultType match { def requireString(genExpr: GeneratedExpression) = genExpr.resultType match {
case STRING_TYPE_INFO => // ok case STRING_TYPE_INFO => // ok
case _ => throw new CodeGenException("String expression type expected.") case _ => throw new CodeGenException("String expression type expected.")
...@@ -112,6 +125,8 @@ object CodeGenUtils { ...@@ -112,6 +125,8 @@ object CodeGenUtils {
case _ => throw new CodeGenException("Boolean expression type expected.") case _ => throw new CodeGenException("Boolean expression type expected.")
} }
// ----------------------------------------------------------------------------------------------
def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType) def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType)
def isReference(typeInfo: TypeInformation[_]): Boolean = typeInfo match { def isReference(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
...@@ -126,27 +141,6 @@ object CodeGenUtils { ...@@ -126,27 +141,6 @@ object CodeGenUtils {
case _ => true case _ => true
} }
def isNumeric(genExpr: GeneratedExpression): Boolean = isNumeric(genExpr.resultType)
def isNumeric(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
case nti: NumericTypeInfo[_] => true
case _ => false
}
def isString(genExpr: GeneratedExpression): Boolean = isString(genExpr.resultType)
def isString(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
case STRING_TYPE_INFO => true
case _ => false
}
def isBoolean(genExpr: GeneratedExpression): Boolean = isBoolean(genExpr.resultType)
def isBoolean(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
case BOOLEAN_TYPE_INFO => true
case _ => false
}
// ---------------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------------
sealed abstract class FieldAccessor sealed abstract class FieldAccessor
......
...@@ -18,10 +18,12 @@
package org.apache.flink.api.table.codegen
+import java.math.{BigDecimal => JBigDecimal}
import org.apache.calcite.rex._
-import org.apache.calcite.sql.{SqlLiteral, SqlOperator}
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
+import org.apache.calcite.sql.{SqlLiteral, SqlOperator}
import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, Function, MapFunction}
import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
...@@ -32,8 +34,9 @@ import org.apache.flink.api.table.codegen.CodeGenUtils._
import org.apache.flink.api.table.codegen.Indenter.toISC
import org.apache.flink.api.table.codegen.calls.ScalarFunctions
import org.apache.flink.api.table.codegen.calls.ScalarOperators._
-import org.apache.flink.api.table.typeutils.{TypeConverter, RowTypeInfo}
-import TypeConverter.sqlTypeToTypeInfo
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isNumeric, isString}
+import org.apache.flink.api.table.typeutils.TypeConverter.sqlTypeToTypeInfo
import scala.collection.JavaConversions._
import scala.collection.mutable
...@@ -542,7 +545,7 @@ class CodeGenerator(
      case BOOLEAN =>
        generateNonNullLiteral(resultType, literal.getValue3.toString)
      case TINYINT =>
-       val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
+       val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
        if (decimal.isValidByte) {
          generateNonNullLiteral(resultType, decimal.byteValue().toString)
        }
...@@ -550,7 +553,7 @@
          throw new CodeGenException("Decimal can not be converted to byte.")
        }
      case SMALLINT =>
-       val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
+       val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
        if (decimal.isValidShort) {
          generateNonNullLiteral(resultType, decimal.shortValue().toString)
        }
...@@ -558,7 +561,7 @@
          throw new CodeGenException("Decimal can not be converted to short.")
        }
      case INTEGER =>
-       val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
+       val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
        if (decimal.isValidInt) {
          generateNonNullLiteral(resultType, decimal.intValue().toString)
        }
...@@ -566,29 +569,36 @@
          throw new CodeGenException("Decimal can not be converted to integer.")
        }
      case BIGINT =>
-       val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
+       val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
        if (decimal.isValidLong) {
-         generateNonNullLiteral(resultType, decimal.longValue().toString)
+         generateNonNullLiteral(resultType, decimal.longValue().toString + "L")
        }
        else {
          throw new CodeGenException("Decimal can not be converted to long.")
        }
      case FLOAT =>
-       val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
-       if (decimal.isValidFloat) {
-         generateNonNullLiteral(resultType, decimal.floatValue().toString + "f")
-       }
-       else {
-         throw new CodeGenException("Decimal can not be converted to float.")
+       val floatValue = value.asInstanceOf[JBigDecimal].floatValue()
+       floatValue match {
+         case Float.NaN => generateNonNullLiteral(resultType, "java.lang.Float.NaN")
+         case Float.NegativeInfinity =>
+           generateNonNullLiteral(resultType, "java.lang.Float.NEGATIVE_INFINITY")
+         case Float.PositiveInfinity =>
+           generateNonNullLiteral(resultType, "java.lang.Float.POSITIVE_INFINITY")
+         case _ => generateNonNullLiteral(resultType, floatValue.toString + "f")
        }
      case DOUBLE =>
-       val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
-       if (decimal.isValidDouble) {
-         generateNonNullLiteral(resultType, decimal.doubleValue().toString)
-       }
-       else {
-         throw new CodeGenException("Decimal can not be converted to double.")
+       val doubleValue = value.asInstanceOf[JBigDecimal].doubleValue()
+       doubleValue match {
+         case Double.NaN => generateNonNullLiteral(resultType, "java.lang.Double.NaN")
+         case Double.NegativeInfinity =>
+           generateNonNullLiteral(resultType, "java.lang.Double.NEGATIVE_INFINITY")
+         case Double.PositiveInfinity =>
+           generateNonNullLiteral(resultType, "java.lang.Double.POSITIVE_INFINITY")
+         case _ => generateNonNullLiteral(resultType, doubleValue.toString + "d")
        }
+     case DECIMAL =>
+       val decimalField = addReusableDecimal(value.asInstanceOf[JBigDecimal])
+       generateNonNullLiteral(resultType, decimalField)
      case VARCHAR | CHAR =>
        generateNonNullLiteral(resultType, "\"" + value.toString + "\"")
      case SYMBOL =>
...@@ -630,7 +640,7 @@ class CodeGenerator(
        val left = operands.head
        val right = operands(1)
        requireString(left)
-       generateArithmeticOperator("+", nullCheck, resultType, left, right)
+       generateStringConcatOperator(nullCheck, left, right)
      case MINUS if isNumeric(resultType) =>
        val left = operands.head
...@@ -674,37 +684,39 @@
      case EQUALS =>
        val left = operands.head
        val right = operands(1)
-       checkNumericOrString(left, right)
        generateEquals(nullCheck, left, right)
      case NOT_EQUALS =>
        val left = operands.head
        val right = operands(1)
-       checkNumericOrString(left, right)
        generateNotEquals(nullCheck, left, right)
      case GREATER_THAN =>
        val left = operands.head
        val right = operands(1)
-       checkNumericOrString(left, right)
+       requireComparable(left)
+       requireComparable(right)
        generateComparison(">", nullCheck, left, right)
      case GREATER_THAN_OR_EQUAL =>
        val left = operands.head
        val right = operands(1)
-       checkNumericOrString(left, right)
+       requireComparable(left)
+       requireComparable(right)
        generateComparison(">=", nullCheck, left, right)
      case LESS_THAN =>
        val left = operands.head
        val right = operands(1)
-       checkNumericOrString(left, right)
+       requireComparable(left)
+       requireComparable(right)
        generateComparison("<", nullCheck, left, right)
      case LESS_THAN_OR_EQUAL =>
        val left = operands.head
        val right = operands(1)
-       checkNumericOrString(left, right)
+       requireComparable(left)
+       requireComparable(right)
        generateComparison("<=", nullCheck, left, right)
      case IS_NULL =>
...@@ -775,14 +787,6 @@ class CodeGenerator(
  // generator helping methods
  // ----------------------------------------------------------------------------------------------
-  def checkNumericOrString(left: GeneratedExpression, right: GeneratedExpression): Unit = {
-    if (isNumeric(left)) {
-      requireNumeric(right)
-    } else if (isString(left)) {
-      requireString(right)
-    }
-  }
private def generateInputAccess( private def generateInputAccess(
inputType: TypeInformation[Any], inputType: TypeInformation[Any],
inputTerm: String, inputTerm: String,
...@@ -1036,4 +1040,18 @@ class CodeGenerator( ...@@ -1036,4 +1040,18 @@ class CodeGenerator(
fieldTerm fieldTerm
} }
def addReusableDecimal(decimal: JBigDecimal): String = decimal match {
case JBigDecimal.ZERO => "java.math.BigDecimal.ZERO"
case JBigDecimal.ONE => "java.math.BigDecimal.ONE"
case JBigDecimal.TEN => "java.math.BigDecimal.TEN"
case _ =>
val fieldTerm = newName("decimal")
val fieldDecimal =
s"""
|transient java.math.BigDecimal $fieldTerm =
| new java.math.BigDecimal("${decimal.toString}");
|""".stripMargin
reusableMemberStatements.add(fieldDecimal)
fieldTerm
}
} }
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
*/ */
package org.apache.flink.api.table.codegen.calls package org.apache.flink.api.table.codegen.calls
import java.math.{BigDecimal => JBigDecimal}
import org.apache.calcite.linq4j.tree.Types import org.apache.calcite.linq4j.tree.Types
import org.apache.calcite.runtime.SqlFunctions import org.apache.calcite.runtime.SqlFunctions
...@@ -26,4 +28,5 @@ object BuiltInMethods { ...@@ -26,4 +28,5 @@ object BuiltInMethods {
val POWER = Types.lookupMethod(classOf[Math], "pow", classOf[Double], classOf[Double]) val POWER = Types.lookupMethod(classOf[Math], "pow", classOf[Double], classOf[Double])
val LN = Types.lookupMethod(classOf[Math], "log", classOf[Double]) val LN = Types.lookupMethod(classOf[Math], "log", classOf[Double])
val ABS = Types.lookupMethod(classOf[SqlFunctions], "abs", classOf[Double]) val ABS = Types.lookupMethod(classOf[SqlFunctions], "abs", classOf[Double])
val ABS_DEC = Types.lookupMethod(classOf[SqlFunctions], "abs", classOf[JBigDecimal])
} }
...@@ -20,7 +20,8 @@ package org.apache.flink.api.table.codegen.calls ...@@ -20,7 +20,8 @@ package org.apache.flink.api.table.codegen.calls
import java.lang.reflect.Method import java.lang.reflect.Method
import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{DOUBLE_TYPE_INFO, FLOAT_TYPE_INFO} import org.apache.flink.api.common.typeinfo.BasicTypeInfo.
{DOUBLE_TYPE_INFO, FLOAT_TYPE_INFO,BIG_DEC_TYPE_INFO}
import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedExpression} import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedExpression}
/** /**
...@@ -33,7 +34,7 @@ class FloorCeilCallGen(method: Method) extends MultiTypeMethodCallGen(method) { ...@@ -33,7 +34,7 @@ class FloorCeilCallGen(method: Method) extends MultiTypeMethodCallGen(method) {
operands: Seq[GeneratedExpression]) operands: Seq[GeneratedExpression])
: GeneratedExpression = { : GeneratedExpression = {
operands.head.resultType match { operands.head.resultType match {
case FLOAT_TYPE_INFO | DOUBLE_TYPE_INFO => case FLOAT_TYPE_INFO | DOUBLE_TYPE_INFO | BIG_DEC_TYPE_INFO =>
super.generate(codeGenerator, operands) super.generate(codeGenerator, operands)
case _ => case _ =>
operands.head // no floor/ceil necessary operands.head // no floor/ceil necessary
......
...@@ -142,16 +142,31 @@ object ScalarFunctions { ...@@ -142,16 +142,31 @@ object ScalarFunctions {
Seq(DOUBLE_TYPE_INFO), Seq(DOUBLE_TYPE_INFO),
new MultiTypeMethodCallGen(BuiltInMethods.ABS)) new MultiTypeMethodCallGen(BuiltInMethods.ABS))
addSqlFunction(
ABS,
Seq(BIG_DEC_TYPE_INFO),
new MultiTypeMethodCallGen(BuiltInMethods.ABS_DEC))
addSqlFunction( addSqlFunction(
FLOOR, FLOOR,
Seq(DOUBLE_TYPE_INFO), Seq(DOUBLE_TYPE_INFO),
new FloorCeilCallGen(BuiltInMethod.FLOOR.method)) new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
addSqlFunction(
FLOOR,
Seq(BIG_DEC_TYPE_INFO),
new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
addSqlFunction( addSqlFunction(
CEIL, CEIL,
Seq(DOUBLE_TYPE_INFO), Seq(DOUBLE_TYPE_INFO),
new FloorCeilCallGen(BuiltInMethod.CEIL.method)) new FloorCeilCallGen(BuiltInMethod.CEIL.method))
addSqlFunction(
CEIL,
Seq(BIG_DEC_TYPE_INFO),
new FloorCeilCallGen(BuiltInMethod.CEIL.method))
// ---------------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------------
......
...@@ -18,12 +18,23 @@ ...@@ -18,12 +18,23 @@
package org.apache.flink.api.table.codegen.calls package org.apache.flink.api.table.codegen.calls
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, NumericTypeInfo, TypeInformation} import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation}
import org.apache.flink.api.table.codegen.CodeGenUtils._ import org.apache.flink.api.table.codegen.CodeGenUtils._
import org.apache.flink.api.table.codegen.{CodeGenException, GeneratedExpression} import org.apache.flink.api.table.codegen.{CodeGenException, GeneratedExpression}
import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isBoolean, isComparable, isDecimal, isNumeric}
object ScalarOperators { object ScalarOperators {
def generateStringConcatOperator(
nullCheck: Boolean,
left: GeneratedExpression,
right: GeneratedExpression)
: GeneratedExpression = {
generateOperatorIfNotNull(nullCheck, STRING_TYPE_INFO, left, right) {
(leftTerm, rightTerm) => s"$leftTerm + $rightTerm"
}
}
  def generateArithmeticOperator(
      operator: String,
      nullCheck: Boolean,
...@@ -31,40 +42,17 @@ object ScalarOperators {
      left: GeneratedExpression,
      right: GeneratedExpression)
    : GeneratedExpression = {
-   // String arithmetic // TODO rework
-   if (isString(left)) {
-     generateOperatorIfNotNull(nullCheck, resultType, left, right) {
-       (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm"
-     }
-   }
-   // Numeric arithmetic
-   else if (isNumeric(left) && isNumeric(right)) {
-     val leftType = left.resultType.asInstanceOf[NumericTypeInfo[_]]
-     val rightType = right.resultType.asInstanceOf[NumericTypeInfo[_]]
-     val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
+   val leftCasting = numericCasting(left.resultType, resultType)
+   val rightCasting = numericCasting(right.resultType, resultType)
+   val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
    generateOperatorIfNotNull(nullCheck, resultType, left, right) {
      (leftTerm, rightTerm) =>
-       // no casting required
-       if (leftType == resultType && rightType == resultType) {
-         s"$leftTerm $operator $rightTerm"
-       }
-       // left needs casting
-       else if (leftType != resultType && rightType == resultType) {
-         s"(($resultTypeTerm) $leftTerm) $operator $rightTerm"
-       }
-       // right needs casting
-       else if (leftType == resultType && rightType != resultType) {
-         s"$leftTerm $operator (($resultTypeTerm) $rightTerm)"
-       }
-       // both sides need casting
-       else {
-         s"(($resultTypeTerm) $leftTerm) $operator (($resultTypeTerm) $rightTerm)"
-       }
+       if (isDecimal(resultType)) {
+         s"${leftCasting(leftTerm)}.${arithOpToDecMethod(operator)}(${rightCasting(rightTerm)})"
+       } else {
+         s"($resultTypeTerm) (${leftCasting(leftTerm)} $operator ${rightCasting(rightTerm)})"
+       }
    }
-   }
-   else {
-     throw new CodeGenException("Unsupported arithmetic operation.")
-   }
  }
...@@ -75,7 +63,16 @@ object ScalarOperators { ...@@ -75,7 +63,16 @@ object ScalarOperators {
operand: GeneratedExpression) operand: GeneratedExpression)
: GeneratedExpression = { : GeneratedExpression = {
generateUnaryOperatorIfNotNull(nullCheck, resultType, operand) { generateUnaryOperatorIfNotNull(nullCheck, resultType, operand) {
(operandTerm) => s"$operator($operandTerm)" (operandTerm) =>
if (isDecimal(operand.resultType) && operator == "-") {
s"$operandTerm.negate()"
} else if (isDecimal(operand.resultType) && operator == "+") {
s"$operandTerm"
} else if (isNumeric(operand.resultType)) {
s"$operator($operandTerm)"
} else {
throw new CodeGenException("Unsupported unary operator.")
}
} }
} }
...@@ -84,15 +81,27 @@ object ScalarOperators { ...@@ -84,15 +81,27 @@ object ScalarOperators {
left: GeneratedExpression, left: GeneratedExpression,
right: GeneratedExpression) right: GeneratedExpression)
: GeneratedExpression = { : GeneratedExpression = {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) { // numeric types
if (isReference(left)) { if (isNumeric(left.resultType) && isNumeric(right.resultType)) {
(leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)" generateComparison("==", nullCheck, left, right)
} }
else if (isReference(right)) { // comparable types of same type
(leftTerm, rightTerm) => s"$rightTerm.equals($leftTerm)" else if (isComparable(left.resultType) && left.resultType == right.resultType) {
} generateComparison("==", nullCheck, left, right)
else { }
(leftTerm, rightTerm) => s"$leftTerm == $rightTerm" // non comparable types
else {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
if (isReference(left)) {
(leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)"
}
else if (isReference(right)) {
(leftTerm, rightTerm) => s"$rightTerm.equals($leftTerm)"
}
else {
throw new CodeGenException(s"Incomparable types: ${left.resultType} and " +
s"${right.resultType}")
}
} }
} }
} }
...@@ -102,19 +111,34 @@ object ScalarOperators { ...@@ -102,19 +111,34 @@ object ScalarOperators {
left: GeneratedExpression, left: GeneratedExpression,
right: GeneratedExpression) right: GeneratedExpression)
: GeneratedExpression = { : GeneratedExpression = {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) { // numeric types
if (isReference(left)) { if (isNumeric(left.resultType) && isNumeric(right.resultType)) {
(leftTerm, rightTerm) => s"!($leftTerm.equals($rightTerm))" generateComparison("!=", nullCheck, left, right)
} }
else if (isReference(right)) { // comparable types
(leftTerm, rightTerm) => s"!($rightTerm.equals($leftTerm))" else if (isComparable(left.resultType) && left.resultType == right.resultType) {
} generateComparison("!=", nullCheck, left, right)
else { }
(leftTerm, rightTerm) => s"$leftTerm != $rightTerm" // non comparable types
else {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
if (isReference(left)) {
(leftTerm, rightTerm) => s"!($leftTerm.equals($rightTerm))"
}
else if (isReference(right)) {
(leftTerm, rightTerm) => s"!($rightTerm.equals($leftTerm))"
}
else {
throw new CodeGenException(s"Incomparable types: ${left.resultType} and " +
s"${right.resultType}")
}
} }
} }
} }
/**
* Generates comparison code for numeric types and comparable types of same type.
*/
def generateComparison( def generateComparison(
operator: String, operator: String,
nullCheck: Boolean, nullCheck: Boolean,
...@@ -122,14 +146,38 @@ object ScalarOperators { ...@@ -122,14 +146,38 @@ object ScalarOperators {
right: GeneratedExpression) right: GeneratedExpression)
: GeneratedExpression = { : GeneratedExpression = {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) { generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
if (isString(left) && isString(right)) { // left is decimal or both sides are decimal
(leftTerm, rightTerm) => s"$leftTerm.compareTo($rightTerm) $operator 0" if (isDecimal(left.resultType) && isNumeric(right.resultType)) {
(leftTerm, rightTerm) => {
val operandCasting = numericCasting(right.resultType, left.resultType)
s"$leftTerm.compareTo(${operandCasting(rightTerm)}) $operator 0"
}
}
// right is decimal
else if (isNumeric(left.resultType) && isDecimal(right.resultType)) {
(leftTerm, rightTerm) => {
val operandCasting = numericCasting(left.resultType, right.resultType)
s"${operandCasting(leftTerm)}.compareTo($rightTerm) $operator 0"
}
} }
else if (isNumeric(left) && isNumeric(right)) { // both sides are numeric
else if (isNumeric(left.resultType) && isNumeric(right.resultType)) {
(leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm"
} }
// both sides are boolean
else if (isBoolean(left.resultType) && left.resultType == right.resultType) {
operator match {
case "==" | "!=" => (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm"
case _ => throw new CodeGenException(s"Unsupported boolean comparison '$operator'.")
}
}
// both sides are same comparable type
else if (isComparable(left.resultType) && left.resultType == right.resultType) {
(leftTerm, rightTerm) => s"$leftTerm.compareTo($rightTerm) $operator 0"
}
else { else {
throw new CodeGenException("Comparison is only supported for Strings and numeric types.") throw new CodeGenException(s"Incomparable types: ${left.resultType} and " +
s"${right.resultType}")
} }
} }
} }
...@@ -147,7 +195,7 @@ object ScalarOperators { ...@@ -147,7 +195,7 @@ object ScalarOperators {
|boolean $nullTerm = false; |boolean $nullTerm = false;
|""".stripMargin |""".stripMargin
} }
else if (!nullCheck && isReference(operand.resultType)) { else if (!nullCheck && isReference(operand)) {
s""" s"""
|${operand.code} |${operand.code}
|boolean $resultTerm = ${operand.resultTerm} == null; |boolean $resultTerm = ${operand.resultTerm} == null;
...@@ -177,7 +225,7 @@ object ScalarOperators { ...@@ -177,7 +225,7 @@ object ScalarOperators {
|boolean $nullTerm = false; |boolean $nullTerm = false;
|""".stripMargin |""".stripMargin
} }
else if (!nullCheck && isReference(operand.resultType)) { else if (!nullCheck && isReference(operand)) {
s""" s"""
|${operand.code} |${operand.code}
|boolean $resultTerm = ${operand.resultTerm} != null; |boolean $resultTerm = ${operand.resultTerm} != null;
...@@ -326,63 +374,72 @@ object ScalarOperators {
      nullCheck: Boolean,
      operand: GeneratedExpression,
      targetType: TypeInformation[_])
-   : GeneratedExpression = {
-   targetType match {
-     // identity casting
-     case operand.resultType =>
-       generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
-         (operandTerm) => s"$operandTerm"
-       }
-     // * -> String
-     case STRING_TYPE_INFO =>
-       generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
-         (operandTerm) => s""" "" + $operandTerm"""
-       }
-     // * -> Date
-     case DATE_TYPE_INFO =>
-       throw new CodeGenException("Date type not supported yet.")
-     // * -> Void
-     case VOID_TYPE_INFO =>
-       throw new CodeGenException("Void type not supported.")
-     // * -> Character
-     case CHAR_TYPE_INFO =>
-       throw new CodeGenException("Character type not supported.")
-     // NUMERIC TYPE -> Boolean
-     case BOOLEAN_TYPE_INFO if isNumeric(operand) =>
-       generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
-         (operandTerm) => s"$operandTerm != 0"
-       }
-     // String -> BASIC TYPE (not String, Date, Void, Character)
-     case ti: BasicTypeInfo[_] if isString(operand) =>
-       val wrapperClass = targetType.getTypeClass.getCanonicalName
-       generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
-         (operandTerm) => s"$wrapperClass.valueOf($operandTerm)"
-       }
-     // NUMERIC TYPE -> NUMERIC TYPE
-     case nti: NumericTypeInfo[_] if isNumeric(operand) =>
-       val targetTypeTerm = primitiveTypeTermForTypeInfo(nti)
-       generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
-         (operandTerm) => s"($targetTypeTerm) $operandTerm"
-       }
-     // Boolean -> NUMERIC TYPE
-     case nti: NumericTypeInfo[_] if isBoolean(operand) =>
-       val targetTypeTerm = primitiveTypeTermForTypeInfo(nti)
-       generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
-         (operandTerm) => s"($targetTypeTerm) ($operandTerm ? 1 : 0)"
-       }
-     case _ =>
-       throw new CodeGenException(s"Unsupported cast from '${operand.resultType}'" +
-         s"to '$targetType'.")
-   }
+   : GeneratedExpression = (operand.resultType, targetType) match {
+   // identity casting
+   case (fromTp, toTp) if fromTp == toTp =>
+     operand
+   // * -> String
+   case (_, STRING_TYPE_INFO) =>
+     generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+       (operandTerm) => s""" "" + $operandTerm"""
+     }
+   // * -> Character
+   case (_, CHAR_TYPE_INFO) =>
+     throw new CodeGenException("Character type not supported.")
+   // String -> NUMERIC TYPE (not Character), Boolean
+   case (STRING_TYPE_INFO, _: NumericTypeInfo[_])
+       | (STRING_TYPE_INFO, BOOLEAN_TYPE_INFO) =>
+     val wrapperClass = targetType.getTypeClass.getCanonicalName
+     generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+       (operandTerm) => s"$wrapperClass.valueOf($operandTerm)"
+     }
+   // String -> BigDecimal
+   case (STRING_TYPE_INFO, BIG_DEC_TYPE_INFO) =>
+     val wrapperClass = targetType.getTypeClass.getCanonicalName
+     generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+       (operandTerm) => s"new $wrapperClass($operandTerm)"
+     }
+   // Boolean -> NUMERIC TYPE
+   case (BOOLEAN_TYPE_INFO, nti: NumericTypeInfo[_]) =>
+     val targetTypeTerm = primitiveTypeTermForTypeInfo(nti)
+     generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+       (operandTerm) => s"($targetTypeTerm) ($operandTerm ? 1 : 0)"
+     }
+   // Boolean -> BigDecimal
+   case (BOOLEAN_TYPE_INFO, BIG_DEC_TYPE_INFO) =>
+     generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+       (operandTerm) => s"$operandTerm ? java.math.BigDecimal.ONE : java.math.BigDecimal.ZERO"
+     }
+   // NUMERIC TYPE -> Boolean
+   case (_: NumericTypeInfo[_], BOOLEAN_TYPE_INFO) =>
+     generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+       (operandTerm) => s"$operandTerm != 0"
+     }
+   // BigDecimal -> Boolean
+   case (BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO) =>
+     generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+       (operandTerm) => s"$operandTerm.compareTo(java.math.BigDecimal.ZERO) != 0"
+     }
+   // NUMERIC TYPE, BigDecimal -> NUMERIC TYPE, BigDecimal
+   case (_: NumericTypeInfo[_], _: NumericTypeInfo[_])
+       | (BIG_DEC_TYPE_INFO, _: NumericTypeInfo[_])
+       | (_: NumericTypeInfo[_], BIG_DEC_TYPE_INFO) =>
+     val operandCasting = numericCasting(operand.resultType, targetType)
+     generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+       (operandTerm) => s"${operandCasting(operandTerm)}"
+     }
+   case (from, to) =>
+     throw new CodeGenException(s"Unsupported cast from '$from' to '$to'.")
  }
def generateIfElse( def generateIfElse(
...@@ -519,4 +576,51 @@ object ScalarOperators { ...@@ -519,4 +576,51 @@ object ScalarOperators {
GeneratedExpression(resultTerm, nullTerm, resultCode, resultType) GeneratedExpression(resultTerm, nullTerm, resultCode, resultType)
} }
private def arithOpToDecMethod(operator: String): String = operator match {
case "+" => "add"
case "-" => "subtract"
case "*" => "multiply"
case "/" => "divide"
case "%" => "remainder"
case _ => throw new CodeGenException("Unsupported decimal arithmetic operator.")
}
private def numericCasting(
operandType: TypeInformation[_],
resultType: TypeInformation[_])
: (String) => String = {
def decToPrimMethod(targetType: TypeInformation[_]): String = targetType match {
case BYTE_TYPE_INFO => "byteValueExact"
case SHORT_TYPE_INFO => "shortValueExact"
case INT_TYPE_INFO => "intValueExact"
case LONG_TYPE_INFO => "longValueExact"
case FLOAT_TYPE_INFO => "floatValue"
case DOUBLE_TYPE_INFO => "doubleValue"
case _ => throw new CodeGenException("Unsupported decimal casting type.")
}
val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
// no casting necessary
if (operandType == resultType) {
(operandTerm) => s"$operandTerm"
}
// result type is decimal but numeric operand is not
else if (isDecimal(resultType) && !isDecimal(operandType) && isNumeric(operandType)) {
(operandTerm) =>
s"java.math.BigDecimal.valueOf((${superPrimitive(operandType)}) $operandTerm)"
}
// numeric result type is not decimal but operand is
else if (isNumeric(resultType) && !isDecimal(resultType) && isDecimal(operandType) ) {
(operandTerm) => s"$operandTerm.${decToPrimMethod(resultType)}()"
}
// result type and operand type are numeric but not decimal
else if (isNumeric(operandType) && isNumeric(resultType)
&& !isDecimal(operandType) && !isDecimal(resultType)) {
(operandTerm) => s"(($resultTypeTerm) $operandTerm)"
}
else {
throw new CodeGenException(s"Unsupported casting from $operandType to $resultType.")
}
}
} }
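For orientation only (not part of the commit): a small Scala sketch of the semantics the generated code follows for DECIMAL operands, using java.math.BigDecimal directly. Arithmetic maps to BigDecimal methods via arithOpToDecMethod, and numericCasting narrows with the *ValueExact methods so overflowing or fractional casts fail fast.

import java.math.{BigDecimal => JBigDecimal}

object DecimalSemanticsSketch extends App {
  val a = new JBigDecimal("11.2")
  val b = new JBigDecimal("0.8")

  // "+", "-", "*", "/", "%" map to add, subtract, multiply, divide, remainder
  val sum = a.add(b)                  // generated as: a.add(b)

  // DECIMAL -> integral casts use the exact variants (ArithmeticException on overflow/fraction)
  val asLong = sum.longValueExact()   // generated as: sum.longValueExact()

  // primitive -> DECIMAL casts widen to "long"/"double" first (cf. superPrimitive)
  // and then go through BigDecimal.valueOf
  val fromLong = JBigDecimal.valueOf(42L)

  println(s"$sum $asLong $fromLong")
}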
...@@ -71,23 +71,24 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
    "DOUBLE" ^^ { ti => BasicTypeInfo.DOUBLE_TYPE_INFO } |
    ("BOOL" | "BOOLEAN" ) ^^ { ti => BasicTypeInfo.BOOLEAN_TYPE_INFO } |
    "STRING" ^^ { ti => BasicTypeInfo.STRING_TYPE_INFO } |
-   "DATE" ^^ { ti => BasicTypeInfo.DATE_TYPE_INFO }
+   "DATE" ^^ { ti => BasicTypeInfo.DATE_TYPE_INFO } |
+   "DECIMAL" ^^ { ti => BasicTypeInfo.BIG_DEC_TYPE_INFO }
  // Literals
  lazy val numberLiteral: PackratParser[Expression] =
-   ((wholeNumber <~ ("L" | "l")) | floatingPointNumber | decimalNumber | wholeNumber) ^^ {
-     str =>
-       if (str.endsWith("L") || str.endsWith("l")) {
-         Literal(str.toLong)
-       } else if (str.matches("""-?\d+""")) {
-         Literal(str.toInt)
-       } else if (str.endsWith("f") | str.endsWith("F")) {
-         Literal(str.toFloat)
-       } else {
-         Literal(str.toDouble)
-       }
+   (wholeNumber <~ ("l" | "L")) ^^ { n => Literal(n.toLong) } |
+   (decimalNumber <~ ("p" | "P")) ^^ { n => Literal(BigDecimal(n)) } |
+   (floatingPointNumber | decimalNumber) ^^ {
+     n =>
+       if (n.matches("""-?\d+""")) {
+         Literal(n.toInt)
+       } else if (n.endsWith("f") || n.endsWith("F")) {
+         Literal(n.toFloat)
+       } else {
+         Literal(n.toDouble)
+       }
    }
  lazy val singleQuoteStringLiteral: Parser[Expression] =
    ("'" + """([^'\p{Cntrl}\\]|\\[\\'"bfnrt]|\\u[a-fA-F0-9]{4})*""" + "'").r ^^ {
...@@ -261,7 +262,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
  lazy val unaryMinus: PackratParser[Expression] = "-" ~> composite ^^ { e => UnaryMinus(e) }
-  lazy val unary = unaryNot | unaryMinus | composite
+  lazy val unary = composite | unaryNot | unaryMinus
  // arithmetic
......
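For illustration only (not part of the commit): assuming ExpressionParser.parseExpression is the entry point used by the string-based API, the reworked numberLiteral rule should behave roughly as follows.

import org.apache.flink.api.table.expressions.ExpressionParser

// "p"/"P" suffix: precise DECIMAL literal (a Literal carrying BIG_DEC_TYPE_INFO)
val dec = ExpressionParser.parseExpression("11.2p")

// "l"/"L" suffix: Long literal, as before
val lng = ExpressionParser.parseExpression("42L")

// plain floating point literals still become Double literals
val dbl = ExpressionParser.parseExpression("11.2")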
...@@ -17,18 +17,17 @@ ...@@ -17,18 +17,17 @@
*/ */
package org.apache.flink.api.table.expressions package org.apache.flink.api.table.expressions
import scala.collection.JavaConversions._
import org.apache.calcite.rex.RexNode import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.SqlOperator import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.tools.RelBuilder import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, NumericTypeInfo, TypeInformation} import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isNumeric, isString}
import org.apache.flink.api.table.typeutils.{TypeCheckUtils, TypeCoercion, TypeConverter} import org.apache.flink.api.table.typeutils.{TypeCheckUtils, TypeCoercion}
import org.apache.flink.api.table.validate._ import org.apache.flink.api.table.validate._
import scala.collection.JavaConversions._
abstract class BinaryArithmetic extends BinaryExpression { abstract class BinaryArithmetic extends BinaryExpression {
def sqlOperator: SqlOperator def sqlOperator: SqlOperator
...@@ -45,9 +44,8 @@ abstract class BinaryArithmetic extends BinaryExpression { ...@@ -45,9 +44,8 @@ abstract class BinaryArithmetic extends BinaryExpression {
// TODO: tighten this rule once we implemented type coercion rules during validation // TODO: tighten this rule once we implemented type coercion rules during validation
override def validateInput(): ExprValidationResult = { override def validateInput(): ExprValidationResult = {
if (!left.resultType.isInstanceOf[NumericTypeInfo[_]] || if (!isNumeric(left.resultType) || !isNumeric(right.resultType)) {
!right.resultType.isInstanceOf[NumericTypeInfo[_]]) { ValidationFailure(s"$this requires both operands Numeric, got " +
ValidationFailure(s"$this requires both operands Numeric, get" +
s"${left.resultType} and ${right.resultType}") s"${left.resultType} and ${right.resultType}")
} else { } else {
ValidationSuccess ValidationSuccess
...@@ -61,28 +59,24 @@ case class Plus(left: Expression, right: Expression) extends BinaryArithmetic { ...@@ -61,28 +59,24 @@ case class Plus(left: Expression, right: Expression) extends BinaryArithmetic {
val sqlOperator = SqlStdOperatorTable.PLUS val sqlOperator = SqlStdOperatorTable.PLUS
override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
val l = left.toRexNode if(isString(left.resultType)) {
val r = right.toRexNode val castedRight = Cast(right, BasicTypeInfo.STRING_TYPE_INFO)
if(SqlTypeName.STRING_TYPES.contains(l.getType.getSqlTypeName)) { relBuilder.call(SqlStdOperatorTable.PLUS, left.toRexNode, castedRight.toRexNode)
val cast: RexNode = relBuilder.cast(r, } else if(isString(right.resultType)) {
TypeConverter.typeInfoToSqlType(BasicTypeInfo.STRING_TYPE_INFO)) val castedLeft = Cast(left, BasicTypeInfo.STRING_TYPE_INFO)
relBuilder.call(SqlStdOperatorTable.PLUS, l, cast) relBuilder.call(SqlStdOperatorTable.PLUS, castedLeft.toRexNode, right.toRexNode)
} else if(SqlTypeName.STRING_TYPES.contains(r.getType.getSqlTypeName)) {
val cast: RexNode = relBuilder.cast(l,
TypeConverter.typeInfoToSqlType(BasicTypeInfo.STRING_TYPE_INFO))
relBuilder.call(SqlStdOperatorTable.PLUS, cast, r)
} else { } else {
relBuilder.call(SqlStdOperatorTable.PLUS, l, r) val castedLeft = Cast(left, resultType)
val castedRight = Cast(right, resultType)
relBuilder.call(SqlStdOperatorTable.PLUS, castedLeft.toRexNode, castedRight.toRexNode)
} }
} }
// TODO: tighten this rule once we implemented type coercion rules during validation // TODO: tighten this rule once we implemented type coercion rules during validation
override def validateInput(): ExprValidationResult = { override def validateInput(): ExprValidationResult = {
if (left.resultType == BasicTypeInfo.STRING_TYPE_INFO || if (isString(left.resultType) || isString(right.resultType)) {
right.resultType == BasicTypeInfo.STRING_TYPE_INFO) {
ValidationSuccess ValidationSuccess
} else if (!left.resultType.isInstanceOf[NumericTypeInfo[_]] || } else if (!isNumeric(left.resultType) || !isNumeric(right.resultType)) {
!right.resultType.isInstanceOf[NumericTypeInfo[_]]) {
ValidationFailure(s"$this requires Numeric or String input," + ValidationFailure(s"$this requires Numeric or String input," +
s" get ${left.resultType} and ${right.resultType}") s" get ${left.resultType} and ${right.resultType}")
} else { } else {
......
...@@ -17,17 +17,16 @@ ...@@ -17,17 +17,16 @@
*/ */
package org.apache.flink.api.table.expressions package org.apache.flink.api.table.expressions
import scala.collection.JavaConversions._
import org.apache.calcite.rex.RexNode import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.SqlOperator import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.tools.RelBuilder import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.NumericTypeInfo import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isComparable, isNumeric}
import org.apache.flink.api.table.validate._ import org.apache.flink.api.table.validate._
import scala.collection.JavaConversions._
abstract class BinaryComparison extends BinaryExpression { abstract class BinaryComparison extends BinaryExpression {
def sqlOperator: SqlOperator def sqlOperator: SqlOperator
...@@ -39,11 +38,12 @@ abstract class BinaryComparison extends BinaryExpression { ...@@ -39,11 +38,12 @@ abstract class BinaryComparison extends BinaryExpression {
// TODO: tighten this rule once we implemented type coercion rules during validation // TODO: tighten this rule once we implemented type coercion rules during validation
override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match { override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match {
case (STRING_TYPE_INFO, STRING_TYPE_INFO) => ValidationSuccess case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => ValidationSuccess case (lType, rType) if isComparable(lType) && lType == rType => ValidationSuccess
case (lType, rType) => case (lType, rType) =>
ValidationFailure( ValidationFailure(
s"Comparison is only supported for Strings and numeric types, get $lType and $rType") s"Comparison is only supported for numeric types and comparable types of same type," +
s"got $lType and $rType")
} }
} }
...@@ -53,13 +53,11 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison ...@@ -53,13 +53,11 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
val sqlOperator: SqlOperator = SqlStdOperatorTable.EQUALS val sqlOperator: SqlOperator = SqlStdOperatorTable.EQUALS
override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match { override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match {
case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => ValidationSuccess case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
// TODO widen this rule once we support custom objects as types (FLINK-3916)
case (lType, rType) if lType == rType => ValidationSuccess
case (lType, rType) => case (lType, rType) =>
if (lType != rType) { ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
} else {
ValidationSuccess
}
} }
} }
...@@ -69,13 +67,11 @@ case class NotEqualTo(left: Expression, right: Expression) extends BinaryCompari ...@@ -69,13 +67,11 @@ case class NotEqualTo(left: Expression, right: Expression) extends BinaryCompari
val sqlOperator: SqlOperator = SqlStdOperatorTable.NOT_EQUALS val sqlOperator: SqlOperator = SqlStdOperatorTable.NOT_EQUALS
override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match { override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match {
case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => ValidationSuccess case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
// TODO widen this rule once we support custom objects as types (FLINK-3916)
case (lType, rType) if lType == rType => ValidationSuccess
case (lType, rType) => case (lType, rType) =>
if (lType != rType) { ValidationFailure(s"Inequality predicate on incompatible types: $lType and $rType")
ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
} else {
ValidationSuccess
}
} }
} }
......
...@@ -20,6 +20,7 @@ package org.apache.flink.api.table.expressions
import java.util.Date
import org.apache.calcite.rex.RexNode
+import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.table.typeutils.TypeConverter
...@@ -35,6 +36,9 @@ object Literal {
    case str: String => Literal(str, BasicTypeInfo.STRING_TYPE_INFO)
    case bool: Boolean => Literal(bool, BasicTypeInfo.BOOLEAN_TYPE_INFO)
    case date: Date => Literal(date, BasicTypeInfo.DATE_TYPE_INFO)
+   case javaDec: java.math.BigDecimal => Literal(javaDec, BasicTypeInfo.BIG_DEC_TYPE_INFO)
+   case scalaDec: scala.math.BigDecimal =>
+     Literal(scalaDec.bigDecimal, BasicTypeInfo.BIG_DEC_TYPE_INFO)
  }
}
...@@ -42,7 +46,13 @@ case class Literal(value: Any, resultType: TypeInformation[_]) extends LeafExpre
  override def toString = s"$value"
  override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
-   relBuilder.literal(value)
+   resultType match {
+     case BasicTypeInfo.BIG_DEC_TYPE_INFO =>
+       val bigDecValue = value.asInstanceOf[java.math.BigDecimal]
+       val decType = relBuilder.getTypeFactory.createSqlType(SqlTypeName.DECIMAL)
+       relBuilder.getRexBuilder.makeExactLiteral(bigDecValue, decType)
+     case _ => relBuilder.literal(value)
+   }
  }
}
......
...@@ -61,6 +61,7 @@ trait DataSetRel extends RelNode with FlinkRel { ...@@ -61,6 +61,7 @@ trait DataSetRel extends RelNode with FlinkRel {
case SqlTypeName.DOUBLE => s + 8 case SqlTypeName.DOUBLE => s + 8
case SqlTypeName.VARCHAR => s + 12 case SqlTypeName.VARCHAR => s + 12
case SqlTypeName.CHAR => s + 1 case SqlTypeName.CHAR => s + 1
case SqlTypeName.DECIMAL => s + 12
case _ => throw new TableException("Unsupported data type encountered") case _ => throw new TableException("Unsupported data type encountered")
} }
} }
......
...@@ -155,6 +155,8 @@ object AggregateUtil { ...@@ -155,6 +155,8 @@ object AggregateUtil {
new FloatSumAggregate new FloatSumAggregate
case DOUBLE => case DOUBLE =>
new DoubleSumAggregate new DoubleSumAggregate
case DECIMAL =>
new DecimalSumAggregate
case sqlType: SqlTypeName => case sqlType: SqlTypeName =>
throw new TableException("Sum aggregate does no support type:" + sqlType) throw new TableException("Sum aggregate does no support type:" + sqlType)
} }
...@@ -173,6 +175,8 @@ object AggregateUtil { ...@@ -173,6 +175,8 @@ object AggregateUtil {
new FloatAvgAggregate new FloatAvgAggregate
case DOUBLE => case DOUBLE =>
new DoubleAvgAggregate new DoubleAvgAggregate
case DECIMAL =>
new DecimalAvgAggregate
case sqlType: SqlTypeName => case sqlType: SqlTypeName =>
throw new TableException("Avg aggregate does no support type:" + sqlType) throw new TableException("Avg aggregate does no support type:" + sqlType)
} }
...@@ -192,6 +196,8 @@ object AggregateUtil { ...@@ -192,6 +196,8 @@ object AggregateUtil {
new FloatMinAggregate new FloatMinAggregate
case DOUBLE => case DOUBLE =>
new DoubleMinAggregate new DoubleMinAggregate
case DECIMAL =>
new DecimalMinAggregate
case BOOLEAN => case BOOLEAN =>
new BooleanMinAggregate new BooleanMinAggregate
case sqlType: SqlTypeName => case sqlType: SqlTypeName =>
...@@ -211,6 +217,8 @@ object AggregateUtil { ...@@ -211,6 +217,8 @@ object AggregateUtil {
new FloatMaxAggregate new FloatMaxAggregate
case DOUBLE => case DOUBLE =>
new DoubleMaxAggregate new DoubleMaxAggregate
case DECIMAL =>
new DecimalMaxAggregate
case BOOLEAN => case BOOLEAN =>
new BooleanMaxAggregate new BooleanMaxAggregate
case sqlType: SqlTypeName => case sqlType: SqlTypeName =>
......
...@@ -18,8 +18,9 @@ ...@@ -18,8 +18,9 @@
package org.apache.flink.api.table.runtime.aggregate package org.apache.flink.api.table.runtime.aggregate
import com.google.common.math.LongMath import com.google.common.math.LongMath
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row import org.apache.flink.api.table.Row
import java.math.BigDecimal
import java.math.BigInteger import java.math.BigInteger
abstract class AvgAggregate[T] extends Aggregate[T] { abstract class AvgAggregate[T] extends Aggregate[T] {
...@@ -251,3 +252,45 @@ class DoubleAvgAggregate extends FloatingAvgAggregate[Double] { ...@@ -251,3 +252,45 @@ class DoubleAvgAggregate extends FloatingAvgAggregate[Double] {
} }
} }
} }
class DecimalAvgAggregate extends AvgAggregate[BigDecimal] {
override def intermediateDataType = Array(
BasicTypeInfo.BIG_DEC_TYPE_INFO,
BasicTypeInfo.LONG_TYPE_INFO)
override def initiate(partial: Row): Unit = {
partial.setField(partialSumIndex, BigDecimal.ZERO)
partial.setField(partialCountIndex, 0L)
}
override def prepare(value: Any, partial: Row): Unit = {
if (value == null) {
initiate(partial)
} else {
val input = value.asInstanceOf[BigDecimal]
partial.setField(partialSumIndex, input)
partial.setField(partialCountIndex, 1L)
}
}
override def merge(partial: Row, buffer: Row): Unit = {
val partialSum = partial.productElement(partialSumIndex).asInstanceOf[BigDecimal]
val partialCount = partial.productElement(partialCountIndex).asInstanceOf[Long]
val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[BigDecimal]
val bufferCount = buffer.productElement(partialCountIndex).asInstanceOf[Long]
buffer.setField(partialSumIndex, partialSum.add(bufferSum))
buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
}
override def evaluate(buffer: Row): BigDecimal = {
val bufferCount = buffer.productElement(partialCountIndex).asInstanceOf[Long]
if (bufferCount != 0) {
val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[BigDecimal]
bufferSum.divide(BigDecimal.valueOf(bufferCount))
} else {
null.asInstanceOf[BigDecimal]
}
}
}
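For readers who have not seen the partial-aggregation contract before, the following is a minimal sketch of how the new `DecimalAvgAggregate` could be driven by hand. It relies only on the `Row` and `Aggregate` methods visible in this diff (`new Row(arity)`, `setField`, `productElement`, `initiate`/`prepare`/`merge`/`evaluate`) and assumes that, as in the other `AvgAggregate` subclasses, `setAggOffsetInRow(0)` places the partial sum at field 0 and the count at field 1. The inputs are chosen so the final `divide` (which uses no explicit `MathContext`) has a terminating result.

```scala
import java.math.BigDecimal
import org.apache.flink.api.table.Row
import org.apache.flink.api.table.runtime.aggregate.DecimalAvgAggregate

object DecimalAvgSketch {
  def main(args: Array[String]): Unit = {
    val agg = new DecimalAvgAggregate
    agg.setAggOffsetInRow(0) // partial sum at index 0, partial count at index 1 (assumed)

    // one buffer row accumulates one partial row per input value
    val buffer = new Row(2)
    agg.initiate(buffer)

    for (v <- Seq(new BigDecimal("1.5"), new BigDecimal("2.5"), null)) {
      val partial = new Row(2)
      agg.prepare(v, partial)  // a null input leaves the partial at (ZERO, 0)
      agg.merge(partial, buffer)
    }

    // (1.5 + 2.5) / 2
    println(agg.evaluate(buffer)) // 2.0
  }
}
```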
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
*/ */
package org.apache.flink.api.table.runtime.aggregate package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row import org.apache.flink.api.table.Row
...@@ -125,3 +127,45 @@ class BooleanMaxAggregate extends MaxAggregate[Boolean] { ...@@ -125,3 +127,45 @@ class BooleanMaxAggregate extends MaxAggregate[Boolean] {
override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO) override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)
} }
class DecimalMaxAggregate extends Aggregate[BigDecimal] {
protected var maxIndex: Int = _
override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
override def initiate(intermediate: Row): Unit = {
intermediate.setField(maxIndex, null)
}
override def prepare(value: Any, partial: Row): Unit = {
if (value == null) {
initiate(partial)
} else {
partial.setField(maxIndex, value)
}
}
override def merge(partial: Row, buffer: Row): Unit = {
val partialValue = partial.productElement(maxIndex).asInstanceOf[BigDecimal]
if (partialValue != null) {
val bufferValue = buffer.productElement(maxIndex).asInstanceOf[BigDecimal]
if (bufferValue != null) {
val max = if (partialValue.compareTo(bufferValue) > 0) partialValue else bufferValue
buffer.setField(maxIndex, max)
} else {
buffer.setField(maxIndex, partialValue)
}
}
}
override def evaluate(buffer: Row): BigDecimal = {
buffer.productElement(maxIndex).asInstanceOf[BigDecimal]
}
override def supportPartial: Boolean = true
override def setAggOffsetInRow(aggOffset: Int): Unit = {
maxIndex = aggOffset
}
}
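A similarly hedged sketch of the null handling above: `initiate` seeds the buffer with `null` and `merge` skips `null` partials, so a group without any non-null value evaluates to `null`, while non-null values are compared with `compareTo`. The sketch assumes the same `Row`/`Aggregate` API as before.

```scala
import java.math.BigDecimal
import org.apache.flink.api.table.Row
import org.apache.flink.api.table.runtime.aggregate.DecimalMaxAggregate

object DecimalMaxSketch {
  def main(args: Array[String]): Unit = {
    val agg = new DecimalMaxAggregate
    agg.setAggOffsetInRow(0)

    def maxOf(values: Seq[BigDecimal]): BigDecimal = {
      val buffer = new Row(1)
      agg.initiate(buffer)           // buffer starts out as null
      values.foreach { v =>
        val partial = new Row(1)
        agg.prepare(v, partial)
        agg.merge(partial, buffer)   // null partials are skipped
      }
      agg.evaluate(buffer)
    }

    println(maxOf(Seq(null, new BigDecimal("-999.999"), new BigDecimal("1000.000001")))) // 1000.000001
    println(maxOf(Seq(null, null)))                                                      // null
  }
}
```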
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
*/ */
package org.apache.flink.api.table.runtime.aggregate package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row import org.apache.flink.api.table.Row
...@@ -125,3 +126,45 @@ class BooleanMinAggregate extends MinAggregate[Boolean] { ...@@ -125,3 +126,45 @@ class BooleanMinAggregate extends MinAggregate[Boolean] {
override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO) override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)
} }
class DecimalMinAggregate extends Aggregate[BigDecimal] {
protected var minIndex: Int = _
override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
override def initiate(intermediate: Row): Unit = {
intermediate.setField(minIndex, null)
}
override def prepare(value: Any, partial: Row): Unit = {
if (value == null) {
initiate(partial)
} else {
partial.setField(minIndex, value)
}
}
override def merge(partial: Row, buffer: Row): Unit = {
val partialValue = partial.productElement(minIndex).asInstanceOf[BigDecimal]
if (partialValue != null) {
val bufferValue = buffer.productElement(minIndex).asInstanceOf[BigDecimal]
if (bufferValue != null) {
val min = if (partialValue.compareTo(bufferValue) < 0) partialValue else bufferValue
buffer.setField(minIndex, min)
} else {
buffer.setField(minIndex, partialValue)
}
}
}
override def evaluate(buffer: Row): BigDecimal = {
buffer.productElement(minIndex).asInstanceOf[BigDecimal]
}
override def supportPartial: Boolean = true
override def setAggOffsetInRow(aggOffset: Int): Unit = {
minIndex = aggOffset
}
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
*/ */
package org.apache.flink.api.table.runtime.aggregate package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row import org.apache.flink.api.table.Row
...@@ -85,3 +86,45 @@ class FloatSumAggregate extends SumAggregate[Float] { ...@@ -85,3 +86,45 @@ class FloatSumAggregate extends SumAggregate[Float] {
class DoubleSumAggregate extends SumAggregate[Double] { class DoubleSumAggregate extends SumAggregate[Double] {
override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO) override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO)
} }
class DecimalSumAggregate extends Aggregate[BigDecimal] {
protected var sumIndex: Int = _
override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
override def initiate(partial: Row): Unit = {
partial.setField(sumIndex, null)
}
override def merge(partial1: Row, buffer: Row): Unit = {
val partialValue = partial1.productElement(sumIndex).asInstanceOf[BigDecimal]
if (partialValue != null) {
val bufferValue = buffer.productElement(sumIndex).asInstanceOf[BigDecimal]
if (bufferValue != null) {
buffer.setField(sumIndex, partialValue.add(bufferValue))
} else {
buffer.setField(sumIndex, partialValue)
}
}
}
override def evaluate(buffer: Row): BigDecimal = {
buffer.productElement(sumIndex).asInstanceOf[BigDecimal]
}
override def prepare(value: Any, partial: Row): Unit = {
if (value == null) {
initiate(partial)
} else {
val input = value.asInstanceOf[BigDecimal]
partial.setField(sumIndex, input)
}
}
override def supportPartial: Boolean = true
override def setAggOffsetInRow(aggOffset: Int): Unit = {
sumIndex = aggOffset
}
}
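Because `supportPartial` is `true`, buffers that were accumulated independently (for example in a combine phase) can themselves be merged again. A small sketch under the same assumptions as the earlier ones:

```scala
import java.math.BigDecimal
import org.apache.flink.api.table.Row
import org.apache.flink.api.table.runtime.aggregate.DecimalSumAggregate

object DecimalSumSketch {
  def main(args: Array[String]): Unit = {
    val agg = new DecimalSumAggregate
    agg.setAggOffsetInRow(0)

    // accumulate one buffer per "node"
    def sumOf(values: Seq[BigDecimal]): Row = {
      val buffer = new Row(1)
      agg.initiate(buffer)
      values.foreach { v =>
        val partial = new Row(1)
        agg.prepare(v, partial)
        agg.merge(partial, buffer)
      }
      buffer
    }

    val buffer1 = sumOf(Seq(new BigDecimal("1.10"), new BigDecimal("2.20")))
    val buffer2 = sumOf(Seq(new BigDecimal("-0.30"), null))

    // buffers can be merged once more, as a combiner would do
    agg.merge(buffer1, buffer2)
    println(agg.evaluate(buffer2)) // 3.00
  }
}
```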
...@@ -17,17 +17,37 @@ ...@@ -17,17 +17,37 @@
*/ */
package org.apache.flink.api.table.typeutils package org.apache.flink.api.table.typeutils
import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO, STRING_TYPE_INFO}
import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation} import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation}
import org.apache.flink.api.table.validate._ import org.apache.flink.api.table.validate._
object TypeCheckUtils { object TypeCheckUtils {
def assertNumericExpr(dataType: TypeInformation[_], caller: String): ExprValidationResult = { def isNumeric(dataType: TypeInformation[_]): Boolean = dataType match {
if (dataType.isInstanceOf[NumericTypeInfo[_]]) { case _: NumericTypeInfo[_] => true
case BIG_DEC_TYPE_INFO => true
case _ => false
}
def isString(dataType: TypeInformation[_]): Boolean = dataType == STRING_TYPE_INFO
def isBoolean(dataType: TypeInformation[_]): Boolean = dataType == BOOLEAN_TYPE_INFO
def isDecimal(dataType: TypeInformation[_]): Boolean = dataType == BIG_DEC_TYPE_INFO
def isComparable(dataType: TypeInformation[_]): Boolean =
classOf[Comparable[_]].isAssignableFrom(dataType.getTypeClass)
def assertNumericExpr(
dataType: TypeInformation[_],
caller: String)
: ExprValidationResult = dataType match {
case _: NumericTypeInfo[_] =>
ValidationSuccess ValidationSuccess
} else { case BIG_DEC_TYPE_INFO =>
ValidationSuccess
case _ =>
ValidationFailure(s"$caller requires numeric types, get $dataType here") ValidationFailure(s"$caller requires numeric types, get $dataType here")
}
} }
def assertOrderableExpr(dataType: TypeInformation[_], caller: String): ExprValidationResult = { def assertOrderableExpr(dataType: TypeInformation[_], caller: String): ExprValidationResult = {
......
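The new predicates make the special-casing of `BIG_DEC_TYPE_INFO` explicit: it is not a `NumericTypeInfo`, but the Table API wants to treat it as numeric, decimal and comparable. A quick REPL check, assuming the helpers from this commit are on the classpath:

```scala
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.table.typeutils.TypeCheckUtils._

isNumeric(BIG_DEC_TYPE_INFO)    // true, via the explicit BIG_DEC_TYPE_INFO case
isNumeric(INT_TYPE_INFO)        // true, via NumericTypeInfo
isNumeric(STRING_TYPE_INFO)     // false
isDecimal(BIG_DEC_TYPE_INFO)    // true
isComparable(BIG_DEC_TYPE_INFO) // true, java.math.BigDecimal implements Comparable
```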
...@@ -37,11 +37,14 @@ object TypeCoercion { ...@@ -37,11 +37,14 @@ object TypeCoercion {
def widerTypeOf(tp1: TypeInformation[_], tp2: TypeInformation[_]): Option[TypeInformation[_]] = { def widerTypeOf(tp1: TypeInformation[_], tp2: TypeInformation[_]): Option[TypeInformation[_]] = {
(tp1, tp2) match { (tp1, tp2) match {
case (tp1, tp2) if tp1 == tp2 => Some(tp1) case (ti1, ti2) if ti1 == ti2 => Some(ti1)
case (_, STRING_TYPE_INFO) => Some(STRING_TYPE_INFO) case (_, STRING_TYPE_INFO) => Some(STRING_TYPE_INFO)
case (STRING_TYPE_INFO, _) => Some(STRING_TYPE_INFO) case (STRING_TYPE_INFO, _) => Some(STRING_TYPE_INFO)
case (_, BIG_DEC_TYPE_INFO) => Some(BIG_DEC_TYPE_INFO)
case (BIG_DEC_TYPE_INFO, _) => Some(BIG_DEC_TYPE_INFO)
case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) => case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) =>
val higherIndex = numericWideningPrecedence.lastIndexWhere(t => t == tp1 || t == tp2) val higherIndex = numericWideningPrecedence.lastIndexWhere(t => t == tp1 || t == tp2)
Some(numericWideningPrecedence(higherIndex)) Some(numericWideningPrecedence(higherIndex))
...@@ -56,6 +59,8 @@ object TypeCoercion { ...@@ -56,6 +59,8 @@ object TypeCoercion {
def canSafelyCast(from: TypeInformation[_], to: TypeInformation[_]): Boolean = (from, to) match { def canSafelyCast(from: TypeInformation[_], to: TypeInformation[_]): Boolean = (from, to) match {
case (_, STRING_TYPE_INFO) => true case (_, STRING_TYPE_INFO) => true
case (_: NumericTypeInfo[_], BIG_DEC_TYPE_INFO) => true
case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) => case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) =>
if (numericWideningPrecedence.indexOf(from) < numericWideningPrecedence.indexOf(to)) { if (numericWideningPrecedence.indexOf(from) < numericWideningPrecedence.indexOf(to)) {
true true
...@@ -71,21 +76,24 @@ object TypeCoercion { ...@@ -71,21 +76,24 @@ object TypeCoercion {
* Note: This may lose information during the cast. * Note: This may lose information during the cast.
*/ */
def canCast(from: TypeInformation[_], to: TypeInformation[_]): Boolean = (from, to) match { def canCast(from: TypeInformation[_], to: TypeInformation[_]): Boolean = (from, to) match {
case (from, to) if from == to => true case (fromTp, toTp) if fromTp == toTp => true
case (_, STRING_TYPE_INFO) => true case (_, STRING_TYPE_INFO) => true
case (_, DATE_TYPE_INFO) => false // Date type not supported yet.
case (_, VOID_TYPE_INFO) => false // Void type not supported
case (_, CHAR_TYPE_INFO) => false // Character type not supported. case (_, CHAR_TYPE_INFO) => false // Character type not supported.
case (STRING_TYPE_INFO, _: NumericTypeInfo[_]) => true case (STRING_TYPE_INFO, _: NumericTypeInfo[_]) => true
case (STRING_TYPE_INFO, BOOLEAN_TYPE_INFO) => true case (STRING_TYPE_INFO, BOOLEAN_TYPE_INFO) => true
case (STRING_TYPE_INFO, BIG_DEC_TYPE_INFO) => true
case (BOOLEAN_TYPE_INFO, _: NumericTypeInfo[_]) => true case (BOOLEAN_TYPE_INFO, _: NumericTypeInfo[_]) => true
case (BOOLEAN_TYPE_INFO, BIG_DEC_TYPE_INFO) => true
case (_: NumericTypeInfo[_], BOOLEAN_TYPE_INFO) => true case (_: NumericTypeInfo[_], BOOLEAN_TYPE_INFO) => true
case (BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO) => true
case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => true case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => true
case (BIG_DEC_TYPE_INFO, _: NumericTypeInfo[_]) => true
case (_: NumericTypeInfo[_], BIG_DEC_TYPE_INFO) => true
case _ => false case _ => false
} }
......
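Taken together, the coercion changes mean that DECIMAL absorbs any numeric type when a wider type is needed, widening a numeric type into DECIMAL is always safe, and explicit casts between DECIMAL and the other numeric, string and boolean types are permitted even though they may lose information. A REPL sketch of the cases visible in this diff, assuming the classes from this commit are available:

```scala
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.table.typeutils.TypeCoercion

TypeCoercion.widerTypeOf(INT_TYPE_INFO, BIG_DEC_TYPE_INFO)   // Some(BIG_DEC_TYPE_INFO)
TypeCoercion.canSafelyCast(INT_TYPE_INFO, BIG_DEC_TYPE_INFO) // true: widening into decimal
TypeCoercion.canCast(BIG_DEC_TYPE_INFO, INT_TYPE_INFO)       // true, but may lose information
TypeCoercion.canCast(STRING_TYPE_INFO, BIG_DEC_TYPE_INFO)    // true
```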
...@@ -56,6 +56,7 @@ object TypeConverter { ...@@ -56,6 +56,7 @@ object TypeConverter {
case STRING_TYPE_INFO => VARCHAR case STRING_TYPE_INFO => VARCHAR
case STRING_VALUE_TYPE_INFO => VARCHAR case STRING_VALUE_TYPE_INFO => VARCHAR
case DATE_TYPE_INFO => DATE case DATE_TYPE_INFO => DATE
case BIG_DEC_TYPE_INFO => DECIMAL
case CHAR_TYPE_INFO | CHAR_VALUE_TYPE_INFO => case CHAR_TYPE_INFO | CHAR_VALUE_TYPE_INFO =>
throw new TableException("Character type is not supported.") throw new TableException("Character type is not supported.")
...@@ -74,6 +75,7 @@ object TypeConverter { ...@@ -74,6 +75,7 @@ object TypeConverter {
case DOUBLE => DOUBLE_TYPE_INFO case DOUBLE => DOUBLE_TYPE_INFO
case VARCHAR | CHAR => STRING_TYPE_INFO case VARCHAR | CHAR => STRING_TYPE_INFO
case DATE => DATE_TYPE_INFO case DATE => DATE_TYPE_INFO
case DECIMAL => BIG_DEC_TYPE_INFO
case NULL => case NULL =>
throw new TableException("Type NULL is not supported. " + throw new TableException("Type NULL is not supported. " +
......
...@@ -98,7 +98,9 @@ class AggregationsITCase( ...@@ -98,7 +98,9 @@ class AggregationsITCase(
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config) val tEnv = TableEnvironment.getTableEnvironment(env, config)
val sqlQuery = "SELECT avg(_1), avg(_2), avg(_3), avg(_4), avg(_5), avg(_6), count(_7)" + val sqlQuery =
"SELECT avg(_1), avg(_2), avg(_3), avg(_4), avg(_5), avg(_6), count(_7), " +
" sum(CAST(_6 AS DECIMAL))" +
"FROM MyTable" "FROM MyTable"
val ds = env.fromElements( val ds = env.fromElements(
...@@ -108,7 +110,7 @@ class AggregationsITCase( ...@@ -108,7 +110,7 @@ class AggregationsITCase(
val result = tEnv.sql(sqlQuery) val result = tEnv.sql(sqlQuery)
val expected = "1,1,1,1,1.5,1.5,2" val expected = "1,1,1,1,1.5,1.5,2,3.0"
val results = result.toDataSet[Row].collect() val results = result.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected) TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
......
...@@ -157,6 +157,23 @@ class ExpressionsITCase( ...@@ -157,6 +157,23 @@ class ExpressionsITCase(
TestBaseUtils.compareResultAsText(results.asJava, expected) TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
@Test
def testDecimalLiteral(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val t = env
.fromElements(
(BigDecimal("78.454654654654654").bigDecimal, BigDecimal("4E+9999").bigDecimal)
)
.toTable(tEnv, 'a, 'b)
.select('a, 'b, BigDecimal("11.2"), BigDecimal("11.2").bigDecimal)
val expected = "78.454654654654654,4E+9999,11.2,11.2"
val results = t.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
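Note how the test mixes the two decimal representations: the row fields carry `java.math.BigDecimal`, while the Scala expression side uses `scala.math.BigDecimal`, which is only a wrapper around the Java value. In a REPL:

```scala
val scalaDec = scala.math.BigDecimal("11.2")
val javaDec: java.math.BigDecimal = scalaDec.bigDecimal // unwraps the underlying Java value
javaDec.toString                                        // "11.2"
```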
// Date literals not yet supported // Date literals not yet supported
@Ignore @Ignore
@Test @Test
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.table.expressions
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.scala.table._
import org.apache.flink.api.table.Row
import org.apache.flink.api.table.expressions.utils.ExpressionTestBase
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.junit.Test
class DecimalTypeTest extends ExpressionTestBase {
@Test
def testDecimalLiterals(): Unit = {
// implicit double
testAllApis(
11.2,
"11.2",
"11.2",
"11.2")
// implicit double
testAllApis(
0.7623533651719233,
"0.7623533651719233",
"0.7623533651719233",
"0.7623533651719233")
// explicit decimal (with precision of 19)
testAllApis(
BigDecimal("1234567891234567891"),
"1234567891234567891p",
"1234567891234567891",
"1234567891234567891")
// explicit decimal (high precision, not SQL compliant)
testTableApi(
BigDecimal("123456789123456789123456789"),
"123456789123456789123456789p",
"123456789123456789123456789")
// explicit decimal (high precision, not SQL compliant)
testTableApi(
BigDecimal("12.3456789123456789123456789"),
"12.3456789123456789123456789p",
"12.3456789123456789123456789")
}
@Test
def testDecimalBorders(): Unit = {
testAllApis(
Double.MaxValue,
Double.MaxValue.toString,
Double.MaxValue.toString,
Double.MaxValue.toString)
testAllApis(
Double.MinValue,
Double.MinValue.toString,
Double.MinValue.toString,
Double.MinValue.toString)
testAllApis(
Double.MinValue.cast(FLOAT_TYPE_INFO),
s"${Double.MinValue}.cast(FLOAT)",
s"CAST(${Double.MinValue} AS FLOAT)",
Float.NegativeInfinity.toString)
testAllApis(
Byte.MinValue.cast(BYTE_TYPE_INFO),
s"(${Byte.MinValue}).cast(BYTE)",
s"CAST(${Byte.MinValue} AS TINYINT)",
Byte.MinValue.toString)
testAllApis(
Byte.MinValue.cast(BYTE_TYPE_INFO) - 1.cast(BYTE_TYPE_INFO),
s"(${Byte.MinValue}).cast(BYTE) - (1).cast(BYTE)",
s"CAST(${Byte.MinValue} AS TINYINT) - CAST(1 AS TINYINT)",
Byte.MaxValue.toString)
testAllApis(
Short.MinValue.cast(SHORT_TYPE_INFO),
s"(${Short.MinValue}).cast(SHORT)",
s"CAST(${Short.MinValue} AS SMALLINT)",
Short.MinValue.toString)
testAllApis(
Int.MinValue.cast(INT_TYPE_INFO) - 1,
s"(${Int.MinValue}).cast(INT) - 1",
s"CAST(${Int.MinValue} AS INT) - 1",
Int.MaxValue.toString)
testAllApis(
Long.MinValue.cast(LONG_TYPE_INFO),
s"(${Long.MinValue}L).cast(LONG)",
s"CAST(${Long.MinValue} AS BIGINT)",
Long.MinValue.toString)
}
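The border tests lean on plain JVM semantics: subtracting below the TINYINT/INT minimum wraps around in two's complement, and narrowing the smallest double to FLOAT overflows to negative infinity. A REPL check, independent of the Table API:

```scala
(Byte.MinValue - 1).toByte   // 127, i.e. Byte.MaxValue
Int.MinValue - 1             // 2147483647, i.e. Int.MaxValue
Double.MinValue.toFloat      // -Infinity
```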
@Test
def testDecimalCasting(): Unit = {
// from String
testTableApi(
"123456789123456789123456789".cast(BIG_DEC_TYPE_INFO),
"'123456789123456789123456789'.cast(DECIMAL)",
"123456789123456789123456789")
// from double
testAllApis(
'f3.cast(BIG_DEC_TYPE_INFO),
"f3.cast(DECIMAL)",
"CAST(f3 AS DECIMAL)",
"4.2")
// to double
testAllApis(
'f0.cast(DOUBLE_TYPE_INFO),
"f0.cast(DOUBLE)",
"CAST(f0 AS DOUBLE)",
"1.2345678912345679E8")
// to int
testAllApis(
'f4.cast(INT_TYPE_INFO),
"f4.cast(INT)",
"CAST(f4 AS INT)",
"123456789")
// to long
testAllApis(
'f4.cast(LONG_TYPE_INFO),
"f4.cast(LONG)",
"CAST(f4 AS BIGINT)",
"123456789")
// to boolean (not SQL compliant)
testTableApi(
'f1.cast(BOOLEAN_TYPE_INFO),
"f1.cast(BOOL)",
"true")
testTableApi(
'f5.cast(BOOLEAN_TYPE_INFO),
"f5.cast(BOOL)",
"false")
testTableApi(
BigDecimal("123456789.123456789123456789").cast(DOUBLE_TYPE_INFO),
"(123456789.123456789123456789p).cast(DOUBLE)",
"1.2345678912345679E8")
}
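The expected string "1.2345678912345679E8" is simply the closest IEEE 754 double to the source decimal, independent of how the cast itself is implemented. A REPL check:

```scala
new java.math.BigDecimal("123456789.123456789123456789").doubleValue()
// 1.2345678912345679E8
```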
@Test
def testDecimalArithmetic(): Unit = {
// implicit cast to decimal
testAllApis(
'f1 + 12,
"f1 + 12",
"f1 + 12",
"123456789123456789123456801")
// implicit cast to decimal
testAllApis(
Literal(12) + 'f1,
"12 + f1",
"12 + f1",
"123456789123456789123456801")
// implicit cast to decimal
testAllApis(
'f1 + 12.3,
"f1 + 12.3",
"f1 + 12.3",
"123456789123456789123456801.3")
// implicit cast to decimal
testAllApis(
Literal(12.3) + 'f1,
"12.3 + f1",
"12.3 + f1",
"123456789123456789123456801.3")
testAllApis(
'f1 + 'f1,
"f1 + f1",
"f1 + f1",
"246913578246913578246913578")
testAllApis(
'f1 - 'f1,
"f1 - f1",
"f1 - f1",
"0")
testAllApis(
'f1 * 'f1,
"f1 * f1",
"f1 * f1",
"15241578780673678546105778281054720515622620750190521")
testAllApis(
'f1 / 'f1,
"f1 / f1",
"f1 / f1",
"1")
testAllApis(
'f1 % 'f1,
"f1 % f1",
"MOD(f1, f1)",
"0")
testAllApis(
-'f0,
"-f0",
"-f0",
"-123456789.123456789123456789")
}
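The expected strings follow directly from `java.math.BigDecimal` arithmetic, where addition keeps the larger scale of the two operands; for example, the "f1 + 12.3" case above can be reproduced in a REPL:

```scala
new java.math.BigDecimal("123456789123456789123456789").add(new java.math.BigDecimal("12.3"))
// 123456789123456789123456801.3
```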
@Test
def testDecimalComparison(): Unit = {
testAllApis(
'f1 < 12,
"f1 < 12",
"f1 < 12",
"false")
testAllApis(
'f1 > 12,
"f1 > 12",
"f1 > 12",
"true")
testAllApis(
'f1 === 12,
"f1 === 12",
"f1 = 12",
"false")
testAllApis(
'f5 === 0,
"f5 === 0",
"f5 = 0",
"true")
testAllApis(
'f1 === BigDecimal("123456789123456789123456789"),
"f1 === 123456789123456789123456789p",
"f1 = CAST('123456789123456789123456789' AS DECIMAL)",
"true")
testAllApis(
'f1 !== BigDecimal("123456789123456789123456789"),
"f1 !== 123456789123456789123456789p",
"f1 <> CAST('123456789123456789123456789' AS DECIMAL)",
"false")
testAllApis(
'f4 < 'f0,
"f4 < f0",
"f4 < f0",
"true")
// TODO add all tests if FLINK-4070 is fixed
testSqlApi(
"12 < f1",
"true")
}
// ----------------------------------------------------------------------------------------------
def testData = {
val testData = new Row(6)
testData.setField(0, BigDecimal("123456789.123456789123456789").bigDecimal)
testData.setField(1, BigDecimal("123456789123456789123456789").bigDecimal)
testData.setField(2, 42)
testData.setField(3, 4.2)
testData.setField(4, BigDecimal("123456789").bigDecimal)
testData.setField(5, BigDecimal("0.000").bigDecimal)
testData
}
def typeInfo = {
new RowTypeInfo(Seq(
BIG_DEC_TYPE_INFO,
BIG_DEC_TYPE_INFO,
INT_TYPE_INFO,
DOUBLE_TYPE_INFO,
BIG_DEC_TYPE_INFO,
BIG_DEC_TYPE_INFO)).asInstanceOf[TypeInformation[Any]]
}
}
...@@ -20,15 +20,13 @@ package org.apache.flink.api.table.expressions ...@@ -20,15 +20,13 @@ package org.apache.flink.api.table.expressions
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.table.expressions.utils.ExpressionEvaluator
import org.apache.flink.api.scala.table._ import org.apache.flink.api.scala.table._
import org.apache.flink.api.table.Row import org.apache.flink.api.table.Row
import org.apache.flink.api.table.expressions.{Expression, ExpressionParser} import org.apache.flink.api.table.expressions.utils.ExpressionTestBase
import org.apache.flink.api.table.typeutils.RowTypeInfo import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.junit.Assert.assertEquals
import org.junit.Test import org.junit.Test
class ScalarFunctionsTest { class ScalarFunctionsTest extends ExpressionTestBase {
// ---------------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------------
// String functions // String functions
...@@ -36,19 +34,19 @@ class ScalarFunctionsTest { ...@@ -36,19 +34,19 @@ class ScalarFunctionsTest {
@Test @Test
def testSubstring(): Unit = { def testSubstring(): Unit = {
testFunction( testAllApis(
'f0.substring(2), 'f0.substring(2),
"f0.substring(2)", "f0.substring(2)",
"SUBSTRING(f0, 2)", "SUBSTRING(f0, 2)",
"his is a test String.") "his is a test String.")
testFunction( testAllApis(
'f0.substring(2, 5), 'f0.substring(2, 5),
"f0.substring(2, 5)", "f0.substring(2, 5)",
"SUBSTRING(f0, 2, 5)", "SUBSTRING(f0, 2, 5)",
"his i") "his i")
testFunction( testAllApis(
'f0.substring(1, 'f7), 'f0.substring(1, 'f7),
"f0.substring(1, f7)", "f0.substring(1, f7)",
"SUBSTRING(f0, 1, f7)", "SUBSTRING(f0, 1, f7)",
...@@ -57,25 +55,25 @@ class ScalarFunctionsTest { ...@@ -57,25 +55,25 @@ class ScalarFunctionsTest {
@Test @Test
def testTrim(): Unit = { def testTrim(): Unit = {
testFunction( testAllApis(
'f8.trim(), 'f8.trim(),
"f8.trim()", "f8.trim()",
"TRIM(f8)", "TRIM(f8)",
"This is a test String.") "This is a test String.")
testFunction( testAllApis(
'f8.trim(removeLeading = true, removeTrailing = true, " "), 'f8.trim(removeLeading = true, removeTrailing = true, " "),
"trim(f8)", "trim(f8)",
"TRIM(f8)", "TRIM(f8)",
"This is a test String.") "This is a test String.")
testFunction( testAllApis(
'f8.trim(removeLeading = false, removeTrailing = true, " "), 'f8.trim(removeLeading = false, removeTrailing = true, " "),
"f8.trim(TRAILING, ' ')", "f8.trim(TRAILING, ' ')",
"TRIM(TRAILING FROM f8)", "TRIM(TRAILING FROM f8)",
" This is a test String.") " This is a test String.")
testFunction( testAllApis(
'f0.trim(removeLeading = true, removeTrailing = true, "."), 'f0.trim(removeLeading = true, removeTrailing = true, "."),
"trim(BOTH, '.', f0)", "trim(BOTH, '.', f0)",
"TRIM(BOTH '.' FROM f0)", "TRIM(BOTH '.' FROM f0)",
...@@ -84,13 +82,13 @@ class ScalarFunctionsTest { ...@@ -84,13 +82,13 @@ class ScalarFunctionsTest {
@Test @Test
def testCharLength(): Unit = { def testCharLength(): Unit = {
testFunction( testAllApis(
'f0.charLength(), 'f0.charLength(),
"f0.charLength()", "f0.charLength()",
"CHAR_LENGTH(f0)", "CHAR_LENGTH(f0)",
"22") "22")
testFunction( testAllApis(
'f0.charLength(), 'f0.charLength(),
"charLength(f0)", "charLength(f0)",
"CHARACTER_LENGTH(f0)", "CHARACTER_LENGTH(f0)",
...@@ -99,7 +97,7 @@ class ScalarFunctionsTest { ...@@ -99,7 +97,7 @@ class ScalarFunctionsTest {
@Test @Test
def testUpperCase(): Unit = { def testUpperCase(): Unit = {
testFunction( testAllApis(
'f0.upperCase(), 'f0.upperCase(),
"f0.upperCase()", "f0.upperCase()",
"UPPER(f0)", "UPPER(f0)",
...@@ -108,7 +106,7 @@ class ScalarFunctionsTest { ...@@ -108,7 +106,7 @@ class ScalarFunctionsTest {
@Test @Test
def testLowerCase(): Unit = { def testLowerCase(): Unit = {
testFunction( testAllApis(
'f0.lowerCase(), 'f0.lowerCase(),
"f0.lowerCase()", "f0.lowerCase()",
"LOWER(f0)", "LOWER(f0)",
...@@ -117,7 +115,7 @@ class ScalarFunctionsTest { ...@@ -117,7 +115,7 @@ class ScalarFunctionsTest {
@Test @Test
def testInitCap(): Unit = { def testInitCap(): Unit = {
testFunction( testAllApis(
'f0.initCap(), 'f0.initCap(),
"f0.initCap()", "f0.initCap()",
"INITCAP(f0)", "INITCAP(f0)",
...@@ -126,7 +124,7 @@ class ScalarFunctionsTest { ...@@ -126,7 +124,7 @@ class ScalarFunctionsTest {
@Test @Test
def testConcat(): Unit = { def testConcat(): Unit = {
testFunction( testAllApis(
'f0 + 'f0, 'f0 + 'f0,
"f0 + f0", "f0 + f0",
"f0||f0", "f0||f0",
...@@ -135,13 +133,13 @@ class ScalarFunctionsTest { ...@@ -135,13 +133,13 @@ class ScalarFunctionsTest {
@Test @Test
def testLike(): Unit = { def testLike(): Unit = {
testFunction( testAllApis(
'f0.like("Th_s%"), 'f0.like("Th_s%"),
"f0.like('Th_s%')", "f0.like('Th_s%')",
"f0 LIKE 'Th_s%'", "f0 LIKE 'Th_s%'",
"true") "true")
testFunction( testAllApis(
'f0.like("%is a%"), 'f0.like("%is a%"),
"f0.like('%is a%')", "f0.like('%is a%')",
"f0 LIKE '%is a%'", "f0 LIKE '%is a%'",
...@@ -150,13 +148,13 @@ class ScalarFunctionsTest { ...@@ -150,13 +148,13 @@ class ScalarFunctionsTest {
@Test @Test
def testNotLike(): Unit = { def testNotLike(): Unit = {
testFunction( testAllApis(
!'f0.like("Th_s%"), !'f0.like("Th_s%"),
"!f0.like('Th_s%')", "!f0.like('Th_s%')",
"f0 NOT LIKE 'Th_s%'", "f0 NOT LIKE 'Th_s%'",
"false") "false")
testFunction( testAllApis(
!'f0.like("%is a%"), !'f0.like("%is a%"),
"!f0.like('%is a%')", "!f0.like('%is a%')",
"f0 NOT LIKE '%is a%'", "f0 NOT LIKE '%is a%'",
...@@ -165,13 +163,13 @@ class ScalarFunctionsTest { ...@@ -165,13 +163,13 @@ class ScalarFunctionsTest {
@Test @Test
def testSimilar(): Unit = { def testSimilar(): Unit = {
testFunction( testAllApis(
'f0.similar("_*"), 'f0.similar("_*"),
"f0.similar('_*')", "f0.similar('_*')",
"f0 SIMILAR TO '_*'", "f0 SIMILAR TO '_*'",
"true") "true")
testFunction( testAllApis(
'f0.similar("This (is)? a (test)+ Strin_*"), 'f0.similar("This (is)? a (test)+ Strin_*"),
"f0.similar('This (is)? a (test)+ Strin_*')", "f0.similar('This (is)? a (test)+ Strin_*')",
"f0 SIMILAR TO 'This (is)? a (test)+ Strin_*'", "f0 SIMILAR TO 'This (is)? a (test)+ Strin_*'",
...@@ -180,13 +178,13 @@ class ScalarFunctionsTest { ...@@ -180,13 +178,13 @@ class ScalarFunctionsTest {
@Test @Test
def testNotSimilar(): Unit = { def testNotSimilar(): Unit = {
testFunction( testAllApis(
!'f0.similar("_*"), !'f0.similar("_*"),
"!f0.similar('_*')", "!f0.similar('_*')",
"f0 NOT SIMILAR TO '_*'", "f0 NOT SIMILAR TO '_*'",
"false") "false")
testFunction( testAllApis(
!'f0.similar("This (is)? a (test)+ Strin_*"), !'f0.similar("This (is)? a (test)+ Strin_*"),
"!f0.similar('This (is)? a (test)+ Strin_*')", "!f0.similar('This (is)? a (test)+ Strin_*')",
"f0 NOT SIMILAR TO 'This (is)? a (test)+ Strin_*'", "f0 NOT SIMILAR TO 'This (is)? a (test)+ Strin_*'",
...@@ -195,19 +193,19 @@ class ScalarFunctionsTest { ...@@ -195,19 +193,19 @@ class ScalarFunctionsTest {
@Test @Test
def testMod(): Unit = { def testMod(): Unit = {
testFunction( testAllApis(
'f4.mod('f7), 'f4.mod('f7),
"f4.mod(f7)", "f4.mod(f7)",
"MOD(f4, f7)", "MOD(f4, f7)",
"2") "2")
testFunction( testAllApis(
'f4.mod(3), 'f4.mod(3),
"mod(f4, 3)", "mod(f4, 3)",
"MOD(f4, 3)", "MOD(f4, 3)",
"2") "2")
testFunction( testAllApis(
'f4 % 3, 'f4 % 3,
"mod(44, 3)", "mod(44, 3)",
"MOD(44, 3)", "MOD(44, 3)",
...@@ -217,37 +215,43 @@ class ScalarFunctionsTest { ...@@ -217,37 +215,43 @@ class ScalarFunctionsTest {
@Test @Test
def testExp(): Unit = { def testExp(): Unit = {
testFunction( testAllApis(
'f2.exp(), 'f2.exp(),
"f2.exp()", "f2.exp()",
"EXP(f2)", "EXP(f2)",
math.exp(42.toByte).toString) math.exp(42.toByte).toString)
testFunction( testAllApis(
'f3.exp(), 'f3.exp(),
"f3.exp()", "f3.exp()",
"EXP(f3)", "EXP(f3)",
math.exp(43.toShort).toString) math.exp(43.toShort).toString)
testFunction( testAllApis(
'f4.exp(), 'f4.exp(),
"f4.exp()", "f4.exp()",
"EXP(f4)", "EXP(f4)",
math.exp(44.toLong).toString) math.exp(44.toLong).toString)
testFunction( testAllApis(
'f5.exp(), 'f5.exp(),
"f5.exp()", "f5.exp()",
"EXP(f5)", "EXP(f5)",
math.exp(4.5.toFloat).toString) math.exp(4.5.toFloat).toString)
testFunction( testAllApis(
'f6.exp(), 'f6.exp(),
"f6.exp()", "f6.exp()",
"EXP(f6)", "EXP(f6)",
math.exp(4.6).toString) math.exp(4.6).toString)
testFunction( testAllApis(
'f7.exp(),
"exp(3)",
"EXP(3)",
math.exp(3).toString)
testAllApis(
'f7.exp(), 'f7.exp(),
"exp(3)", "exp(3)",
"EXP(3)", "EXP(3)",
...@@ -256,31 +260,31 @@ class ScalarFunctionsTest { ...@@ -256,31 +260,31 @@ class ScalarFunctionsTest {
@Test @Test
def testLog10(): Unit = { def testLog10(): Unit = {
testFunction( testAllApis(
'f2.log10(), 'f2.log10(),
"f2.log10()", "f2.log10()",
"LOG10(f2)", "LOG10(f2)",
math.log10(42.toByte).toString) math.log10(42.toByte).toString)
testFunction( testAllApis(
'f3.log10(), 'f3.log10(),
"f3.log10()", "f3.log10()",
"LOG10(f3)", "LOG10(f3)",
math.log10(43.toShort).toString) math.log10(43.toShort).toString)
testFunction( testAllApis(
'f4.log10(), 'f4.log10(),
"f4.log10()", "f4.log10()",
"LOG10(f4)", "LOG10(f4)",
math.log10(44.toLong).toString) math.log10(44.toLong).toString)
testFunction( testAllApis(
'f5.log10(), 'f5.log10(),
"f5.log10()", "f5.log10()",
"LOG10(f5)", "LOG10(f5)",
math.log10(4.5.toFloat).toString) math.log10(4.5.toFloat).toString)
testFunction( testAllApis(
'f6.log10(), 'f6.log10(),
"f6.log10()", "f6.log10()",
"LOG10(f6)", "LOG10(f6)",
...@@ -289,19 +293,25 @@ class ScalarFunctionsTest { ...@@ -289,19 +293,25 @@ class ScalarFunctionsTest {
@Test @Test
def testPower(): Unit = { def testPower(): Unit = {
testFunction( testAllApis(
'f2.power('f7), 'f2.power('f7),
"f2.power(f7)", "f2.power(f7)",
"POWER(f2, f7)", "POWER(f2, f7)",
math.pow(42.toByte, 3).toString) math.pow(42.toByte, 3).toString)
testFunction( testAllApis(
'f3.power('f6), 'f3.power('f6),
"f3.power(f6)", "f3.power(f6)",
"POWER(f3, f6)", "POWER(f3, f6)",
math.pow(43.toShort, 4.6D).toString) math.pow(43.toShort, 4.6D).toString)
testFunction( testAllApis(
'f4.power('f5),
"f4.power(f5)",
"POWER(f4, f5)",
math.pow(44.toLong, 4.5.toFloat).toString)
testAllApis(
'f4.power('f5), 'f4.power('f5),
"f4.power(f5)", "f4.power(f5)",
"POWER(f4, f5)", "POWER(f4, f5)",
...@@ -310,31 +320,31 @@ class ScalarFunctionsTest { ...@@ -310,31 +320,31 @@ class ScalarFunctionsTest {
@Test @Test
def testLn(): Unit = { def testLn(): Unit = {
testFunction( testAllApis(
'f2.ln(), 'f2.ln(),
"f2.ln()", "f2.ln()",
"LN(f2)", "LN(f2)",
math.log(42.toByte).toString) math.log(42.toByte).toString)
testFunction( testAllApis(
'f3.ln(), 'f3.ln(),
"f3.ln()", "f3.ln()",
"LN(f3)", "LN(f3)",
math.log(43.toShort).toString) math.log(43.toShort).toString)
testFunction( testAllApis(
'f4.ln(), 'f4.ln(),
"f4.ln()", "f4.ln()",
"LN(f4)", "LN(f4)",
math.log(44.toLong).toString) math.log(44.toLong).toString)
testFunction( testAllApis(
'f5.ln(), 'f5.ln(),
"f5.ln()", "f5.ln()",
"LN(f5)", "LN(f5)",
math.log(4.5.toFloat).toString) math.log(4.5.toFloat).toString)
testFunction( testAllApis(
'f6.ln(), 'f6.ln(),
"f6.ln()", "f6.ln()",
"LN(f6)", "LN(f6)",
...@@ -343,102 +353,116 @@ class ScalarFunctionsTest { ...@@ -343,102 +353,116 @@ class ScalarFunctionsTest {
@Test @Test
def testAbs(): Unit = { def testAbs(): Unit = {
testFunction( testAllApis(
'f2.abs(), 'f2.abs(),
"f2.abs()", "f2.abs()",
"ABS(f2)", "ABS(f2)",
"42") "42")
testFunction( testAllApis(
'f3.abs(), 'f3.abs(),
"f3.abs()", "f3.abs()",
"ABS(f3)", "ABS(f3)",
"43") "43")
testFunction( testAllApis(
'f4.abs(), 'f4.abs(),
"f4.abs()", "f4.abs()",
"ABS(f4)", "ABS(f4)",
"44") "44")
testFunction( testAllApis(
'f5.abs(), 'f5.abs(),
"f5.abs()", "f5.abs()",
"ABS(f5)", "ABS(f5)",
"4.5") "4.5")
testFunction( testAllApis(
'f6.abs(), 'f6.abs(),
"f6.abs()", "f6.abs()",
"ABS(f6)", "ABS(f6)",
"4.6") "4.6")
testFunction( testAllApis(
'f9.abs(), 'f9.abs(),
"f9.abs()", "f9.abs()",
"ABS(f9)", "ABS(f9)",
"42") "42")
testFunction( testAllApis(
'f10.abs(), 'f10.abs(),
"f10.abs()", "f10.abs()",
"ABS(f10)", "ABS(f10)",
"43") "43")
testFunction( testAllApis(
'f11.abs(), 'f11.abs(),
"f11.abs()", "f11.abs()",
"ABS(f11)", "ABS(f11)",
"44") "44")
testFunction( testAllApis(
'f12.abs(), 'f12.abs(),
"f12.abs()", "f12.abs()",
"ABS(f12)", "ABS(f12)",
"4.5") "4.5")
testFunction( testAllApis(
'f13.abs(), 'f13.abs(),
"f13.abs()", "f13.abs()",
"ABS(f13)", "ABS(f13)",
"4.6") "4.6")
testAllApis(
'f15.abs(),
"f15.abs()",
"ABS(f15)",
"1231.1231231321321321111")
} }
@Test @Test
def testArithmeticFloorCeil(): Unit = { def testArithmeticFloorCeil(): Unit = {
testFunction( testAllApis(
'f5.floor(), 'f5.floor(),
"f5.floor()", "f5.floor()",
"FLOOR(f5)", "FLOOR(f5)",
"4.0") "4.0")
testFunction( testAllApis(
'f5.ceil(), 'f5.ceil(),
"f5.ceil()", "f5.ceil()",
"CEIL(f5)", "CEIL(f5)",
"5.0") "5.0")
testFunction( testAllApis(
'f3.floor(), 'f3.floor(),
"f3.floor()", "f3.floor()",
"FLOOR(f3)", "FLOOR(f3)",
"43") "43")
testFunction( testAllApis(
'f3.ceil(), 'f3.ceil(),
"f3.ceil()", "f3.ceil()",
"CEIL(f3)", "CEIL(f3)",
"43") "43")
testAllApis(
'f15.floor(),
"f15.floor()",
"FLOOR(f15)",
"-1232")
testAllApis(
'f15.ceil(),
"f15.ceil()",
"CEIL(f15)",
"-1231")
} }
// ---------------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------------
def testFunction( def testData = {
expr: Expression, val testData = new Row(16)
exprString: String,
sqlExpr: String,
expected: String): Unit = {
val testData = new Row(15)
testData.setField(0, "This is a test String.") testData.setField(0, "This is a test String.")
testData.setField(1, true) testData.setField(1, true)
testData.setField(2, 42.toByte) testData.setField(2, 42.toByte)
...@@ -454,8 +478,12 @@ class ScalarFunctionsTest { ...@@ -454,8 +478,12 @@ class ScalarFunctionsTest {
testData.setField(12, -4.5.toFloat) testData.setField(12, -4.5.toFloat)
testData.setField(13, -4.6) testData.setField(13, -4.6)
testData.setField(14, -3) testData.setField(14, -3)
testData.setField(15, BigDecimal("-1231.1231231321321321111").bigDecimal)
testData
}
val typeInfo = new RowTypeInfo(Seq( def typeInfo = {
new RowTypeInfo(Seq(
STRING_TYPE_INFO, STRING_TYPE_INFO,
BOOLEAN_TYPE_INFO, BOOLEAN_TYPE_INFO,
BYTE_TYPE_INFO, BYTE_TYPE_INFO,
...@@ -470,21 +498,7 @@ class ScalarFunctionsTest { ...@@ -470,21 +498,7 @@ class ScalarFunctionsTest {
LONG_TYPE_INFO, LONG_TYPE_INFO,
FLOAT_TYPE_INFO, FLOAT_TYPE_INFO,
DOUBLE_TYPE_INFO, DOUBLE_TYPE_INFO,
INT_TYPE_INFO)).asInstanceOf[TypeInformation[Any]] INT_TYPE_INFO,
BIG_DEC_TYPE_INFO)).asInstanceOf[TypeInformation[Any]]
val exprResult = ExpressionEvaluator.evaluate(testData, typeInfo, expr)
assertEquals(expected, exprResult)
val exprStringResult = ExpressionEvaluator.evaluate(
testData,
typeInfo,
ExpressionParser.parseExpression(exprString))
assertEquals(expected, exprStringResult)
val exprSqlResult = ExpressionEvaluator.evaluate(testData, typeInfo, sqlExpr)
assertEquals(expected, exprSqlResult)
} }
} }
...@@ -20,7 +20,7 @@ package org.apache.flink.api.table.expressions.utils ...@@ -20,7 +20,7 @@ package org.apache.flink.api.table.expressions.utils
import org.apache.calcite.rel.logical.LogicalProject import org.apache.calcite.rel.logical.LogicalProject
import org.apache.calcite.rex.RexNode import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.`type`.SqlTypeName.VARCHAR import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.tools.{Frameworks, RelBuilder} import org.apache.calcite.tools.{Frameworks, RelBuilder}
import org.apache.flink.api.common.functions.{Function, MapFunction} import org.apache.flink.api.common.functions.{Function, MapFunction}
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
...@@ -28,25 +28,29 @@ import org.apache.flink.api.common.typeinfo.TypeInformation ...@@ -28,25 +28,29 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.{DataSet => JDataSet} import org.apache.flink.api.java.{DataSet => JDataSet}
import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment} import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedFunction} import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedFunction}
import org.apache.flink.api.table.expressions.Expression import org.apache.flink.api.table.expressions.{Expression, ExpressionParser}
import org.apache.flink.api.table.runtime.FunctionCompiler import org.apache.flink.api.table.runtime.FunctionCompiler
import org.apache.flink.api.table.{BatchTableEnvironment, TableConfig, TableEnvironment} import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.api.table.{BatchTableEnvironment, Row, TableConfig, TableEnvironment}
import org.junit.Assert._
import org.junit.{After, Before}
import org.mockito.Mockito._ import org.mockito.Mockito._
import scala.collection.mutable
/** /**
* Utility to translate and evaluate an RexNode or Table API expression to a String. * Base test class for expression tests.
*/ */
object ExpressionEvaluator { abstract class ExpressionTestBase {
// TestCompiler that uses current class loader private val testExprs = mutable.LinkedHashSet[(RexNode, String)]()
class TestCompiler[T <: Function] extends FunctionCompiler[T] {
def compile(genFunc: GeneratedFunction[T]): Class[T] =
compile(getClass.getClassLoader, genFunc.name, genFunc.code)
}
private def prepareTable( // setup test utils
typeInfo: TypeInformation[Any]): (String, RelBuilder, TableEnvironment) = { private val tableName = "testTable"
private val context = prepareContext(typeInfo)
private val planner = Frameworks.getPlanner(context._2.getFrameworkConfig)
private def prepareContext(typeInfo: TypeInformation[Any]): (RelBuilder, TableEnvironment) = {
// create DataSetTable // create DataSetTable
val dataSetMock = mock(classOf[DataSet[Any]]) val dataSetMock = mock(classOf[DataSet[Any]])
val jDataSetMock = mock(classOf[JDataSet[Any]]) val jDataSetMock = mock(classOf[JDataSet[Any]])
...@@ -55,65 +59,128 @@ object ExpressionEvaluator { ...@@ -55,65 +59,128 @@ object ExpressionEvaluator {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env) val tEnv = TableEnvironment.getTableEnvironment(env)
val tableName = "myTable"
tEnv.registerDataSet(tableName, dataSetMock) tEnv.registerDataSet(tableName, dataSetMock)
// prepare RelBuilder // prepare RelBuilder
val relBuilder = tEnv.getRelBuilder val relBuilder = tEnv.getRelBuilder
relBuilder.scan(tableName) relBuilder.scan(tableName)
(tableName, relBuilder, tEnv) (relBuilder, tEnv)
} }
def evaluate(data: Any, typeInfo: TypeInformation[Any], sqlExpr: String): String = { def testData: Any
// create DataSetTable
val table = prepareTable(typeInfo)
// create RelNode from SQL expression
val planner = Frameworks.getPlanner(table._3.getFrameworkConfig)
val parsed = planner.parse("SELECT " + sqlExpr + " FROM " + table._1)
val validated = planner.validate(parsed)
val converted = planner.rel(validated)
val expr: RexNode = converted.rel.asInstanceOf[LogicalProject].getChildExps.get(0)
evaluate(data, typeInfo, table._2, expr) def typeInfo: TypeInformation[Any]
}
def evaluate(data: Any, typeInfo: TypeInformation[Any], expr: Expression): String = { @Before
val table = prepareTable(typeInfo) def resetTestExprs() = {
val env = table._3 testExprs.clear()
val resolvedExpr =
env.asInstanceOf[BatchTableEnvironment].scan("myTable").select(expr).
getRelNode.asInstanceOf[LogicalProject].getChildExps.get(0)
evaluate(data, typeInfo, table._2, resolvedExpr)
} }
def evaluate( @After
data: Any, def evaluateExprs() = {
typeInfo: TypeInformation[Any], val relBuilder = context._1
relBuilder: RelBuilder,
rexNode: RexNode): String = {
// generate code for Mapper
val config = new TableConfig() val config = new TableConfig()
val generator = new CodeGenerator(config, false, typeInfo) val generator = new CodeGenerator(config, false, typeInfo)
val genExpr = generator.generateExpression(relBuilder.cast(rexNode, VARCHAR)) // cast to String
// cast expressions to String
val stringTestExprs = testExprs.map(expr => relBuilder.cast(expr._1, VARCHAR)).toSeq
// generate code
val resultType = new RowTypeInfo(Seq.fill(testExprs.size)(STRING_TYPE_INFO))
val genExpr = generator.generateResultExpression(
resultType,
resultType.getFieldNames,
stringTestExprs)
val bodyCode = val bodyCode =
s""" s"""
|${genExpr.code} |${genExpr.code}
|return ${genExpr.resultTerm}; |return ${genExpr.resultTerm};
|""".stripMargin |""".stripMargin
val genFunc = generator.generateFunction[MapFunction[Any, String]]( val genFunc = generator.generateFunction[MapFunction[Any, String]](
"TestFunction", "TestFunction",
classOf[MapFunction[Any, String]], classOf[MapFunction[Any, String]],
bodyCode, bodyCode,
STRING_TYPE_INFO.asInstanceOf[TypeInformation[Any]]) resultType.asInstanceOf[TypeInformation[Any]])
// compile and evaluate // compile and evaluate
val clazz = new TestCompiler[MapFunction[Any, String]]().compile(genFunc) val clazz = new TestCompiler[MapFunction[Any, String]]().compile(genFunc)
val mapper = clazz.newInstance() val mapper = clazz.newInstance()
mapper.map(data) val result = mapper.map(testData).asInstanceOf[Row]
// compare
testExprs
.zipWithIndex
.foreach {
case ((expr, expected), index) =>
assertEquals(s"Wrong result for: $expr", expected, result.productElement(index))
}
} }
private def addSqlTestExpr(sqlExpr: String, expected: String): Unit = {
// create RelNode from SQL expression
val parsed = planner.parse(s"SELECT $sqlExpr FROM $tableName")
val validated = planner.validate(parsed)
val converted = planner.rel(validated)
// extract RexNode
val expr: RexNode = converted.rel.asInstanceOf[LogicalProject].getChildExps.get(0)
testExprs.add((expr, expected))
planner.close()
}
private def addTableApiTestExpr(tableApiExpr: Expression, expected: String): Unit = {
val env = context._2
val expr = env
.asInstanceOf[BatchTableEnvironment]
.scan(tableName)
.select(tableApiExpr)
.getRelNode
.asInstanceOf[LogicalProject]
.getChildExps
.get(0)
testExprs.add((expr, expected))
}
private def addTableApiTestExpr(tableApiString: String, expected: String): Unit = {
addTableApiTestExpr(ExpressionParser.parseExpression(tableApiString), expected)
}
def testAllApis(
expr: Expression,
exprString: String,
sqlExpr: String,
expected: String)
: Unit = {
addTableApiTestExpr(expr, expected)
addTableApiTestExpr(exprString, expected)
addSqlTestExpr(sqlExpr, expected)
}
def testTableApi(
expr: Expression,
exprString: String,
expected: String)
: Unit = {
addTableApiTestExpr(expr, expected)
addTableApiTestExpr(exprString, expected)
}
def testSqlApi(
sqlExpr: String,
expected: String)
: Unit = {
addSqlTestExpr(sqlExpr, expected)
}
// ----------------------------------------------------------------------------------------------
// TestCompiler that uses current class loader
class TestCompiler[T <: Function] extends FunctionCompiler[T] {
def compile(genFunc: GeneratedFunction[T]): Class[T] =
compile(getClass.getClassLoader, genFunc.name, genFunc.code)
}
} }
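With this refactoring an expression test no longer evaluates anything itself: `testAllApis`/`testTableApi`/`testSqlApi` only queue `(RexNode, expected)` pairs, and the `@After` hook generates, compiles and runs a single `MapFunction` over `testData`, comparing the result field by field. A minimal sketch of a test built on the new base class (hypothetical class name and data, mirroring the UPPER case exercised further up):

```scala
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.scala.table._
import org.apache.flink.api.table.Row
import org.apache.flink.api.table.expressions.utils.ExpressionTestBase
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.junit.Test

class UpperCaseExpressionTest extends ExpressionTestBase {

  @Test
  def testUpperCase(): Unit = {
    // only queued here; generated and evaluated after the test method returns
    testAllApis('f0.upperCase(), "f0.upperCase()", "UPPER(f0)", "HELLO")
  }

  def testData = {
    val row = new Row(1)
    row.setField(0, "hello")
    row
  }

  def typeInfo =
    new RowTypeInfo(Seq(STRING_TYPE_INFO)).asInstanceOf[TypeInformation[Any]]
}
```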
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
package org.apache.flink.api.table.runtime.aggregate package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
import org.apache.flink.api.table.Row import org.apache.flink.api.table.Row
import org.junit.Test import org.junit.Test
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
...@@ -63,8 +64,13 @@ abstract class AggregateTestBase[T] { ...@@ -63,8 +64,13 @@ abstract class AggregateTestBase[T] {
finalAgg(rows) finalAgg(rows)
} }
assertEquals(expected, result) (expected, result) match {
case (e: BigDecimal, r: BigDecimal) =>
// BigDecimal.equals() compares both value and scale, but here we are only interested in the value.
assert(e.compareTo(r) == 0)
case _ =>
assertEquals(expected, result)
}
} }
} }
......
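The special case above exists because `java.math.BigDecimal.equals()` is scale-sensitive while `compareTo()` is not; for example, in a REPL:

```scala
new java.math.BigDecimal("0").equals(new java.math.BigDecimal("0.000"))    // false: different scale
new java.math.BigDecimal("0").compareTo(new java.math.BigDecimal("0.000")) // 0: same value
```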
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
package org.apache.flink.api.table.runtime.aggregate package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
abstract class AvgAggregateTestBase[T: Numeric] extends AggregateTestBase[T] { abstract class AvgAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
private val numeric: Numeric[T] = implicitly[Numeric[T]] private val numeric: Numeric[T] = implicitly[Numeric[T]]
...@@ -122,3 +124,31 @@ class DoubleAvgAggregateTest extends AvgAggregateTestBase[Double] { ...@@ -122,3 +124,31 @@ class DoubleAvgAggregateTest extends AvgAggregateTestBase[Double] {
override def aggregator = new DoubleAvgAggregate() override def aggregator = new DoubleAvgAggregate()
} }
class DecimalAvgAggregateTest extends AggregateTestBase[BigDecimal] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(
new BigDecimal("987654321000000"),
new BigDecimal("-0.000000000012345"),
null,
new BigDecimal("0.000000000012345"),
new BigDecimal("-987654321000000"),
null,
new BigDecimal("0")
),
Seq(
null,
null,
null,
null
)
)
override def expectedResults: Seq[BigDecimal] = Seq(
BigDecimal.ZERO,
null
)
override def aggregator: Aggregate[BigDecimal] = new DecimalAvgAggregate()
}
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
package org.apache.flink.api.table.runtime.aggregate package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
abstract class MaxAggregateTestBase[T: Numeric] extends AggregateTestBase[T] { abstract class MaxAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
private val numeric: Numeric[T] = implicitly[Numeric[T]] private val numeric: Numeric[T] = implicitly[Numeric[T]]
...@@ -141,3 +143,35 @@ class BooleanMaxAggregateTest extends AggregateTestBase[Boolean] { ...@@ -141,3 +143,35 @@ class BooleanMaxAggregateTest extends AggregateTestBase[Boolean] {
override def aggregator: Aggregate[Boolean] = new BooleanMaxAggregate() override def aggregator: Aggregate[Boolean] = new BooleanMaxAggregate()
} }
class DecimalMaxAggregateTest extends AggregateTestBase[BigDecimal] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(
new BigDecimal("1"),
new BigDecimal("1000.000001"),
new BigDecimal("-1"),
new BigDecimal("-999.998999"),
null,
new BigDecimal("0"),
new BigDecimal("-999.999"),
null,
new BigDecimal("999.999")
),
Seq(
null,
null,
null,
null,
null
)
)
override def expectedResults: Seq[BigDecimal] = Seq(
new BigDecimal("1000.000001"),
null
)
override def aggregator: Aggregate[BigDecimal] = new DecimalMaxAggregate()
}
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
package org.apache.flink.api.table.runtime.aggregate package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
abstract class MinAggregateTestBase[T: Numeric] extends AggregateTestBase[T] { abstract class MinAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
private val numeric: Numeric[T] = implicitly[Numeric[T]] private val numeric: Numeric[T] = implicitly[Numeric[T]]
...@@ -141,3 +143,35 @@ class BooleanMinAggregateTest extends AggregateTestBase[Boolean] { ...@@ -141,3 +143,35 @@ class BooleanMinAggregateTest extends AggregateTestBase[Boolean] {
override def aggregator: Aggregate[Boolean] = new BooleanMinAggregate() override def aggregator: Aggregate[Boolean] = new BooleanMinAggregate()
} }
class DecimalMinAggregateTest extends AggregateTestBase[BigDecimal] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(
new BigDecimal("1"),
new BigDecimal("1000"),
new BigDecimal("-1"),
new BigDecimal("-999.998999"),
null,
new BigDecimal("0"),
new BigDecimal("-999.999"),
null,
new BigDecimal("999.999")
),
Seq(
null,
null,
null,
null,
null
)
)
override def expectedResults: Seq[BigDecimal] = Seq(
new BigDecimal("-999.999"),
null
)
override def aggregator: Aggregate[BigDecimal] = new DecimalMinAggregate()
}
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
package org.apache.flink.api.table.runtime.aggregate package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
abstract class SumAggregateTestBase[T: Numeric] extends AggregateTestBase[T] { abstract class SumAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
private val numeric: Numeric[T] = implicitly[Numeric[T]] private val numeric: Numeric[T] = implicitly[Numeric[T]]
...@@ -97,3 +99,39 @@ class DoubleSumAggregateTest extends SumAggregateTestBase[Double] { ...@@ -97,3 +99,39 @@ class DoubleSumAggregateTest extends SumAggregateTestBase[Double] {
override def aggregator: Aggregate[Double] = new DoubleSumAggregate override def aggregator: Aggregate[Double] = new DoubleSumAggregate
} }
class DecimalSumAggregateTest extends AggregateTestBase[BigDecimal] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(
new BigDecimal("1"),
new BigDecimal("2"),
new BigDecimal("3"),
null,
new BigDecimal("0"),
new BigDecimal("-1000"),
new BigDecimal("0.000000000002"),
new BigDecimal("1000"),
new BigDecimal("-0.000000000001"),
new BigDecimal("999.999"),
null,
new BigDecimal("4"),
new BigDecimal("-999.999"),
null
),
Seq(
null,
null,
null,
null,
null
)
)
override def expectedResults: Seq[BigDecimal] = Seq(
new BigDecimal("10.000000000001"),
null
)
override def aggregator: Aggregate[BigDecimal] = new DecimalSumAggregate()
}