提交 37defbb4 编写于 作者: T twalthr

[FLINK-3859] [table] Add BigDecimal/BigInteger support to Table API

This closes #2088.
上级 b6e6b818
......@@ -752,7 +752,7 @@ suffixed = cast | as | aggregation | nullCheck | evaluate | functionCall ;
cast = composite , ".cast(" , dataType , ")" ;
dataType = "BYTE" | "SHORT" | "INT" | "LONG" | "FLOAT" | "DOUBLE" | "BOOL" | "BOOLEAN" | "STRING" | "DATE" ;
dataType = "BYTE" | "SHORT" | "INT" | "LONG" | "FLOAT" | "DOUBLE" | "BOOL" | "BOOLEAN" | "STRING" | "DATE" | "DECIMAL";
as = composite , ".as(" , fieldReference , ")" ;
......@@ -773,6 +773,8 @@ nullLiteral = "Null(" , dataType , ")" ;
Here, `literal` is a valid Java literal, `fieldReference` specifies a column in the data, and `functionIdentifier` specifies a supported scalar function. The
column names and function names follow Java identifier syntax. Expressions specified as Strings can also use prefix notation instead of suffix notation to call operators and functions.
If working with exact numeric values or large decimals is required, the Table API also supports Java's BigDecimal type. In the Scala Table API decimals can be defined by `BigDecimal("123456")` and in Java by appending a "p" for precise e.g. `123456p`.
{% top %}
......
......@@ -220,6 +220,14 @@ trait ImplicitExpressionConversions {
def expr = Literal(l)
}
implicit class LiteralByteExpression(b: Byte) extends ImplicitExpressionOperations {
def expr = Literal(b)
}
implicit class LiteralShortExpression(s: Short) extends ImplicitExpressionOperations {
def expr = Literal(s)
}
implicit class LiteralIntExpression(i: Int) extends ImplicitExpressionOperations {
def expr = Literal(i)
}
......@@ -240,11 +248,26 @@ trait ImplicitExpressionConversions {
def expr = Literal(bool)
}
implicit class LiteralJavaDecimalExpression(javaDecimal: java.math.BigDecimal)
extends ImplicitExpressionOperations {
def expr = Literal(javaDecimal)
}
implicit class LiteralScalaDecimalExpression(scalaDecimal: scala.math.BigDecimal)
extends ImplicitExpressionOperations {
def expr = Literal(scalaDecimal.bigDecimal)
}
implicit def symbol2FieldExpression(sym: Symbol): Expression = UnresolvedFieldReference(sym.name)
implicit def byte2Literal(b: Byte): Expression = Literal(b)
implicit def short2Literal(s: Short): Expression = Literal(s)
implicit def int2Literal(i: Int): Expression = Literal(i)
implicit def long2Literal(l: Long): Expression = Literal(l)
implicit def double2Literal(d: Double): Expression = Literal(d)
implicit def float2Literal(d: Float): Expression = Literal(d)
implicit def string2Literal(str: String): Expression = Literal(str)
implicit def boolean2Literal(bool: Boolean): Expression = Literal(bool)
implicit def javaDec2Literal(javaDec: java.math.BigDecimal): Expression = Literal(javaDec)
implicit def scalaDec2Literal(scalaDec: scala.math.BigDecimal): Expression =
Literal(scalaDec.bigDecimal)
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.table
import org.apache.calcite.rel.`type`.RelDataTypeSystemImpl
/**
* Custom type system for Flink.
*/
class FlinkTypeSystem extends RelDataTypeSystemImpl {
// we cannot use Int.MaxValue because of an overflow in Calcites type inference logic
// half should be enough for all use cases
override def getMaxNumericScale: Int = Int.MaxValue / 2
// we cannot use Int.MaxValue because of an overflow in Calcites type inference logic
// half should be enough for all use cases
override def getMaxNumericPrecision: Int = Int.MaxValue / 2
}
......@@ -69,6 +69,7 @@ abstract class TableEnvironment(val config: TableConfig) {
.defaultSchema(tables)
.parserConfig(parserConfig)
.costFactory(new DataSetCostFactory)
.typeSystem(new FlinkTypeSystem)
.build
// the builder for Calcite RelNodes, Calcite's representation of a relational expression tree.
......
......@@ -23,11 +23,11 @@ import java.util.concurrent.atomic.AtomicInteger
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo._
import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeinfo.{FractionalTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.{TypeExtractor, PojoTypeInfo, TupleTypeInfo}
import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo, TypeExtractor}
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeCheckUtils}
object CodeGenUtils {
......@@ -97,11 +97,24 @@ object CodeGenUtils {
case _ => "null"
}
def requireNumeric(genExpr: GeneratedExpression) = genExpr.resultType match {
case nti: NumericTypeInfo[_] => // ok
case _ => throw new CodeGenException("Numeric expression type expected.")
def superPrimitive(typeInfo: TypeInformation[_]): String = typeInfo match {
case _: FractionalTypeInfo[_] => "double"
case _ => "long"
}
// ----------------------------------------------------------------------------------------------
def requireNumeric(genExpr: GeneratedExpression) =
if (!TypeCheckUtils.isNumeric(genExpr.resultType)) {
throw new CodeGenException("Numeric expression type expected, but was " +
s"'${genExpr.resultType}'.")
}
def requireComparable(genExpr: GeneratedExpression) =
if (!TypeCheckUtils.isComparable(genExpr.resultType)) {
throw new CodeGenException(s"Comparable type expected, but was '${genExpr.resultType}'.")
}
def requireString(genExpr: GeneratedExpression) = genExpr.resultType match {
case STRING_TYPE_INFO => // ok
case _ => throw new CodeGenException("String expression type expected.")
......@@ -112,6 +125,8 @@ object CodeGenUtils {
case _ => throw new CodeGenException("Boolean expression type expected.")
}
// ----------------------------------------------------------------------------------------------
def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType)
def isReference(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
......@@ -126,27 +141,6 @@ object CodeGenUtils {
case _ => true
}
def isNumeric(genExpr: GeneratedExpression): Boolean = isNumeric(genExpr.resultType)
def isNumeric(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
case nti: NumericTypeInfo[_] => true
case _ => false
}
def isString(genExpr: GeneratedExpression): Boolean = isString(genExpr.resultType)
def isString(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
case STRING_TYPE_INFO => true
case _ => false
}
def isBoolean(genExpr: GeneratedExpression): Boolean = isBoolean(genExpr.resultType)
def isBoolean(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
case BOOLEAN_TYPE_INFO => true
case _ => false
}
// ----------------------------------------------------------------------------------------------
sealed abstract class FieldAccessor
......
......@@ -18,10 +18,12 @@
package org.apache.flink.api.table.codegen
import java.math.{BigDecimal => JBigDecimal}
import org.apache.calcite.rex._
import org.apache.calcite.sql.{SqlLiteral, SqlOperator}
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.calcite.sql.{SqlLiteral, SqlOperator}
import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, Function, MapFunction}
import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
......@@ -32,8 +34,9 @@ import org.apache.flink.api.table.codegen.CodeGenUtils._
import org.apache.flink.api.table.codegen.Indenter.toISC
import org.apache.flink.api.table.codegen.calls.ScalarFunctions
import org.apache.flink.api.table.codegen.calls.ScalarOperators._
import org.apache.flink.api.table.typeutils.{TypeConverter, RowTypeInfo}
import TypeConverter.sqlTypeToTypeInfo
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isNumeric, isString}
import org.apache.flink.api.table.typeutils.TypeConverter.sqlTypeToTypeInfo
import scala.collection.JavaConversions._
import scala.collection.mutable
......@@ -542,7 +545,7 @@ class CodeGenerator(
case BOOLEAN =>
generateNonNullLiteral(resultType, literal.getValue3.toString)
case TINYINT =>
val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
if (decimal.isValidByte) {
generateNonNullLiteral(resultType, decimal.byteValue().toString)
}
......@@ -550,7 +553,7 @@ class CodeGenerator(
throw new CodeGenException("Decimal can not be converted to byte.")
}
case SMALLINT =>
val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
if (decimal.isValidShort) {
generateNonNullLiteral(resultType, decimal.shortValue().toString)
}
......@@ -558,7 +561,7 @@ class CodeGenerator(
throw new CodeGenException("Decimal can not be converted to short.")
}
case INTEGER =>
val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
if (decimal.isValidInt) {
generateNonNullLiteral(resultType, decimal.intValue().toString)
}
......@@ -566,29 +569,36 @@ class CodeGenerator(
throw new CodeGenException("Decimal can not be converted to integer.")
}
case BIGINT =>
val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
val decimal = BigDecimal(value.asInstanceOf[JBigDecimal])
if (decimal.isValidLong) {
generateNonNullLiteral(resultType, decimal.longValue().toString)
generateNonNullLiteral(resultType, decimal.longValue().toString + "L")
}
else {
throw new CodeGenException("Decimal can not be converted to long.")
}
case FLOAT =>
val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
if (decimal.isValidFloat) {
generateNonNullLiteral(resultType, decimal.floatValue().toString + "f")
}
else {
throw new CodeGenException("Decimal can not be converted to float.")
val floatValue = value.asInstanceOf[JBigDecimal].floatValue()
floatValue match {
case Float.NaN => generateNonNullLiteral(resultType, "java.lang.Float.NaN")
case Float.NegativeInfinity =>
generateNonNullLiteral(resultType, "java.lang.Float.NEGATIVE_INFINITY")
case Float.PositiveInfinity =>
generateNonNullLiteral(resultType, "java.lang.Float.POSITIVE_INFINITY")
case _ => generateNonNullLiteral(resultType, floatValue.toString + "f")
}
case DOUBLE =>
val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal])
if (decimal.isValidDouble) {
generateNonNullLiteral(resultType, decimal.doubleValue().toString)
}
else {
throw new CodeGenException("Decimal can not be converted to double.")
val doubleValue = value.asInstanceOf[JBigDecimal].doubleValue()
doubleValue match {
case Double.NaN => generateNonNullLiteral(resultType, "java.lang.Double.NaN")
case Double.NegativeInfinity =>
generateNonNullLiteral(resultType, "java.lang.Double.NEGATIVE_INFINITY")
case Double.PositiveInfinity =>
generateNonNullLiteral(resultType, "java.lang.Double.POSITIVE_INFINITY")
case _ => generateNonNullLiteral(resultType, doubleValue.toString + "d")
}
case DECIMAL =>
val decimalField = addReusableDecimal(value.asInstanceOf[JBigDecimal])
generateNonNullLiteral(resultType, decimalField)
case VARCHAR | CHAR =>
generateNonNullLiteral(resultType, "\"" + value.toString + "\"")
case SYMBOL =>
......@@ -630,7 +640,7 @@ class CodeGenerator(
val left = operands.head
val right = operands(1)
requireString(left)
generateArithmeticOperator("+", nullCheck, resultType, left, right)
generateStringConcatOperator(nullCheck, left, right)
case MINUS if isNumeric(resultType) =>
val left = operands.head
......@@ -674,37 +684,39 @@ class CodeGenerator(
case EQUALS =>
val left = operands.head
val right = operands(1)
checkNumericOrString(left, right)
generateEquals(nullCheck, left, right)
case NOT_EQUALS =>
val left = operands.head
val right = operands(1)
checkNumericOrString(left, right)
generateNotEquals(nullCheck, left, right)
case GREATER_THAN =>
val left = operands.head
val right = operands(1)
checkNumericOrString(left, right)
requireComparable(left)
requireComparable(right)
generateComparison(">", nullCheck, left, right)
case GREATER_THAN_OR_EQUAL =>
val left = operands.head
val right = operands(1)
checkNumericOrString(left, right)
requireComparable(left)
requireComparable(right)
generateComparison(">=", nullCheck, left, right)
case LESS_THAN =>
val left = operands.head
val right = operands(1)
checkNumericOrString(left, right)
requireComparable(left)
requireComparable(right)
generateComparison("<", nullCheck, left, right)
case LESS_THAN_OR_EQUAL =>
val left = operands.head
val right = operands(1)
checkNumericOrString(left, right)
requireComparable(left)
requireComparable(right)
generateComparison("<=", nullCheck, left, right)
case IS_NULL =>
......@@ -775,14 +787,6 @@ class CodeGenerator(
// generator helping methods
// ----------------------------------------------------------------------------------------------
def checkNumericOrString(left: GeneratedExpression, right: GeneratedExpression): Unit = {
if (isNumeric(left)) {
requireNumeric(right)
} else if (isString(left)) {
requireString(right)
}
}
private def generateInputAccess(
inputType: TypeInformation[Any],
inputTerm: String,
......@@ -1036,4 +1040,18 @@ class CodeGenerator(
fieldTerm
}
def addReusableDecimal(decimal: JBigDecimal): String = decimal match {
case JBigDecimal.ZERO => "java.math.BigDecimal.ZERO"
case JBigDecimal.ONE => "java.math.BigDecimal.ONE"
case JBigDecimal.TEN => "java.math.BigDecimal.TEN"
case _ =>
val fieldTerm = newName("decimal")
val fieldDecimal =
s"""
|transient java.math.BigDecimal $fieldTerm =
| new java.math.BigDecimal("${decimal.toString}");
|""".stripMargin
reusableMemberStatements.add(fieldDecimal)
fieldTerm
}
}
......@@ -17,6 +17,8 @@
*/
package org.apache.flink.api.table.codegen.calls
import java.math.{BigDecimal => JBigDecimal}
import org.apache.calcite.linq4j.tree.Types
import org.apache.calcite.runtime.SqlFunctions
......@@ -26,4 +28,5 @@ object BuiltInMethods {
val POWER = Types.lookupMethod(classOf[Math], "pow", classOf[Double], classOf[Double])
val LN = Types.lookupMethod(classOf[Math], "log", classOf[Double])
val ABS = Types.lookupMethod(classOf[SqlFunctions], "abs", classOf[Double])
val ABS_DEC = Types.lookupMethod(classOf[SqlFunctions], "abs", classOf[JBigDecimal])
}
......@@ -20,7 +20,8 @@ package org.apache.flink.api.table.codegen.calls
import java.lang.reflect.Method
import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{DOUBLE_TYPE_INFO, FLOAT_TYPE_INFO}
import org.apache.flink.api.common.typeinfo.BasicTypeInfo.
{DOUBLE_TYPE_INFO, FLOAT_TYPE_INFO,BIG_DEC_TYPE_INFO}
import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedExpression}
/**
......@@ -33,7 +34,7 @@ class FloorCeilCallGen(method: Method) extends MultiTypeMethodCallGen(method) {
operands: Seq[GeneratedExpression])
: GeneratedExpression = {
operands.head.resultType match {
case FLOAT_TYPE_INFO | DOUBLE_TYPE_INFO =>
case FLOAT_TYPE_INFO | DOUBLE_TYPE_INFO | BIG_DEC_TYPE_INFO =>
super.generate(codeGenerator, operands)
case _ =>
operands.head // no floor/ceil necessary
......
......@@ -142,16 +142,31 @@ object ScalarFunctions {
Seq(DOUBLE_TYPE_INFO),
new MultiTypeMethodCallGen(BuiltInMethods.ABS))
addSqlFunction(
ABS,
Seq(BIG_DEC_TYPE_INFO),
new MultiTypeMethodCallGen(BuiltInMethods.ABS_DEC))
addSqlFunction(
FLOOR,
Seq(DOUBLE_TYPE_INFO),
new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
addSqlFunction(
FLOOR,
Seq(BIG_DEC_TYPE_INFO),
new FloorCeilCallGen(BuiltInMethod.FLOOR.method))
addSqlFunction(
CEIL,
Seq(DOUBLE_TYPE_INFO),
new FloorCeilCallGen(BuiltInMethod.CEIL.method))
addSqlFunction(
CEIL,
Seq(BIG_DEC_TYPE_INFO),
new FloorCeilCallGen(BuiltInMethod.CEIL.method))
// ----------------------------------------------------------------------------------------------
......
......@@ -18,12 +18,23 @@
package org.apache.flink.api.table.codegen.calls
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, NumericTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation}
import org.apache.flink.api.table.codegen.CodeGenUtils._
import org.apache.flink.api.table.codegen.{CodeGenException, GeneratedExpression}
import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isBoolean, isComparable, isDecimal, isNumeric}
object ScalarOperators {
def generateStringConcatOperator(
nullCheck: Boolean,
left: GeneratedExpression,
right: GeneratedExpression)
: GeneratedExpression = {
generateOperatorIfNotNull(nullCheck, STRING_TYPE_INFO, left, right) {
(leftTerm, rightTerm) => s"$leftTerm + $rightTerm"
}
}
def generateArithmeticOperator(
operator: String,
nullCheck: Boolean,
......@@ -31,40 +42,17 @@ object ScalarOperators {
left: GeneratedExpression,
right: GeneratedExpression)
: GeneratedExpression = {
// String arithmetic // TODO rework
if (isString(left)) {
generateOperatorIfNotNull(nullCheck, resultType, left, right) {
(leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm"
}
}
// Numeric arithmetic
else if (isNumeric(left) && isNumeric(right)) {
val leftType = left.resultType.asInstanceOf[NumericTypeInfo[_]]
val rightType = right.resultType.asInstanceOf[NumericTypeInfo[_]]
val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
val leftCasting = numericCasting(left.resultType, resultType)
val rightCasting = numericCasting(right.resultType, resultType)
val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
generateOperatorIfNotNull(nullCheck, resultType, left, right) {
generateOperatorIfNotNull(nullCheck, resultType, left, right) {
(leftTerm, rightTerm) =>
// no casting required
if (leftType == resultType && rightType == resultType) {
s"$leftTerm $operator $rightTerm"
}
// left needs casting
else if (leftType != resultType && rightType == resultType) {
s"(($resultTypeTerm) $leftTerm) $operator $rightTerm"
}
// right needs casting
else if (leftType == resultType && rightType != resultType) {
s"$leftTerm $operator (($resultTypeTerm) $rightTerm)"
}
// both sides need casting
else {
s"(($resultTypeTerm) $leftTerm) $operator (($resultTypeTerm) $rightTerm)"
if (isDecimal(resultType)) {
s"${leftCasting(leftTerm)}.${arithOpToDecMethod(operator)}(${rightCasting(rightTerm)})"
} else {
s"($resultTypeTerm) (${leftCasting(leftTerm)} $operator ${rightCasting(rightTerm)})"
}
}
}
else {
throw new CodeGenException("Unsupported arithmetic operation.")
}
}
......@@ -75,7 +63,16 @@ object ScalarOperators {
operand: GeneratedExpression)
: GeneratedExpression = {
generateUnaryOperatorIfNotNull(nullCheck, resultType, operand) {
(operandTerm) => s"$operator($operandTerm)"
(operandTerm) =>
if (isDecimal(operand.resultType) && operator == "-") {
s"$operandTerm.negate()"
} else if (isDecimal(operand.resultType) && operator == "+") {
s"$operandTerm"
} else if (isNumeric(operand.resultType)) {
s"$operator($operandTerm)"
} else {
throw new CodeGenException("Unsupported unary operator.")
}
}
}
......@@ -84,15 +81,27 @@ object ScalarOperators {
left: GeneratedExpression,
right: GeneratedExpression)
: GeneratedExpression = {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
if (isReference(left)) {
(leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)"
}
else if (isReference(right)) {
(leftTerm, rightTerm) => s"$rightTerm.equals($leftTerm)"
}
else {
(leftTerm, rightTerm) => s"$leftTerm == $rightTerm"
// numeric types
if (isNumeric(left.resultType) && isNumeric(right.resultType)) {
generateComparison("==", nullCheck, left, right)
}
// comparable types of same type
else if (isComparable(left.resultType) && left.resultType == right.resultType) {
generateComparison("==", nullCheck, left, right)
}
// non comparable types
else {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
if (isReference(left)) {
(leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)"
}
else if (isReference(right)) {
(leftTerm, rightTerm) => s"$rightTerm.equals($leftTerm)"
}
else {
throw new CodeGenException(s"Incomparable types: ${left.resultType} and " +
s"${right.resultType}")
}
}
}
}
......@@ -102,19 +111,34 @@ object ScalarOperators {
left: GeneratedExpression,
right: GeneratedExpression)
: GeneratedExpression = {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
if (isReference(left)) {
(leftTerm, rightTerm) => s"!($leftTerm.equals($rightTerm))"
}
else if (isReference(right)) {
(leftTerm, rightTerm) => s"!($rightTerm.equals($leftTerm))"
}
else {
(leftTerm, rightTerm) => s"$leftTerm != $rightTerm"
// numeric types
if (isNumeric(left.resultType) && isNumeric(right.resultType)) {
generateComparison("!=", nullCheck, left, right)
}
// comparable types
else if (isComparable(left.resultType) && left.resultType == right.resultType) {
generateComparison("!=", nullCheck, left, right)
}
// non comparable types
else {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
if (isReference(left)) {
(leftTerm, rightTerm) => s"!($leftTerm.equals($rightTerm))"
}
else if (isReference(right)) {
(leftTerm, rightTerm) => s"!($rightTerm.equals($leftTerm))"
}
else {
throw new CodeGenException(s"Incomparable types: ${left.resultType} and " +
s"${right.resultType}")
}
}
}
}
/**
* Generates comparison code for numeric types and comparable types of same type.
*/
def generateComparison(
operator: String,
nullCheck: Boolean,
......@@ -122,14 +146,38 @@ object ScalarOperators {
right: GeneratedExpression)
: GeneratedExpression = {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
if (isString(left) && isString(right)) {
(leftTerm, rightTerm) => s"$leftTerm.compareTo($rightTerm) $operator 0"
// left is decimal or both sides are decimal
if (isDecimal(left.resultType) && isNumeric(right.resultType)) {
(leftTerm, rightTerm) => {
val operandCasting = numericCasting(right.resultType, left.resultType)
s"$leftTerm.compareTo(${operandCasting(rightTerm)}) $operator 0"
}
}
// right is decimal
else if (isNumeric(left.resultType) && isDecimal(right.resultType)) {
(leftTerm, rightTerm) => {
val operandCasting = numericCasting(left.resultType, right.resultType)
s"${operandCasting(leftTerm)}.compareTo($rightTerm) $operator 0"
}
}
else if (isNumeric(left) && isNumeric(right)) {
// both sides are numeric
else if (isNumeric(left.resultType) && isNumeric(right.resultType)) {
(leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm"
}
// both sides are boolean
else if (isBoolean(left.resultType) && left.resultType == right.resultType) {
operator match {
case "==" | "!=" => (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm"
case _ => throw new CodeGenException(s"Unsupported boolean comparison '$operator'.")
}
}
// both sides are same comparable type
else if (isComparable(left.resultType) && left.resultType == right.resultType) {
(leftTerm, rightTerm) => s"$leftTerm.compareTo($rightTerm) $operator 0"
}
else {
throw new CodeGenException("Comparison is only supported for Strings and numeric types.")
throw new CodeGenException(s"Incomparable types: ${left.resultType} and " +
s"${right.resultType}")
}
}
}
......@@ -147,7 +195,7 @@ object ScalarOperators {
|boolean $nullTerm = false;
|""".stripMargin
}
else if (!nullCheck && isReference(operand.resultType)) {
else if (!nullCheck && isReference(operand)) {
s"""
|${operand.code}
|boolean $resultTerm = ${operand.resultTerm} == null;
......@@ -177,7 +225,7 @@ object ScalarOperators {
|boolean $nullTerm = false;
|""".stripMargin
}
else if (!nullCheck && isReference(operand.resultType)) {
else if (!nullCheck && isReference(operand)) {
s"""
|${operand.code}
|boolean $resultTerm = ${operand.resultTerm} != null;
......@@ -326,63 +374,72 @@ object ScalarOperators {
nullCheck: Boolean,
operand: GeneratedExpression,
targetType: TypeInformation[_])
: GeneratedExpression = {
targetType match {
// identity casting
case operand.resultType =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"$operandTerm"
}
: GeneratedExpression = (operand.resultType, targetType) match {
// identity casting
case (fromTp, toTp) if fromTp == toTp =>
operand
// * -> String
case (_, STRING_TYPE_INFO) =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s""" "" + $operandTerm"""
}
// * -> String
case STRING_TYPE_INFO =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s""" "" + $operandTerm"""
}
// * -> Character
case (_, CHAR_TYPE_INFO) =>
throw new CodeGenException("Character type not supported.")
// * -> Date
case DATE_TYPE_INFO =>
throw new CodeGenException("Date type not supported yet.")
// String -> NUMERIC TYPE (not Character), Boolean
case (STRING_TYPE_INFO, _: NumericTypeInfo[_])
| (STRING_TYPE_INFO, BOOLEAN_TYPE_INFO) =>
val wrapperClass = targetType.getTypeClass.getCanonicalName
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"$wrapperClass.valueOf($operandTerm)"
}
// * -> Void
case VOID_TYPE_INFO =>
throw new CodeGenException("Void type not supported.")
// String -> BigDecimal
case (STRING_TYPE_INFO, BIG_DEC_TYPE_INFO) =>
val wrapperClass = targetType.getTypeClass.getCanonicalName
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"new $wrapperClass($operandTerm)"
}
// * -> Character
case CHAR_TYPE_INFO =>
throw new CodeGenException("Character type not supported.")
// Boolean -> NUMERIC TYPE
case (BOOLEAN_TYPE_INFO, nti: NumericTypeInfo[_]) =>
val targetTypeTerm = primitiveTypeTermForTypeInfo(nti)
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"($targetTypeTerm) ($operandTerm ? 1 : 0)"
}
// NUMERIC TYPE -> Boolean
case BOOLEAN_TYPE_INFO if isNumeric(operand) =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"$operandTerm != 0"
}
// Boolean -> BigDecimal
case (BOOLEAN_TYPE_INFO, BIG_DEC_TYPE_INFO) =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"$operandTerm ? java.math.BigDecimal.ONE : java.math.BigDecimal.ZERO"
}
// String -> BASIC TYPE (not String, Date, Void, Character)
case ti: BasicTypeInfo[_] if isString(operand) =>
val wrapperClass = targetType.getTypeClass.getCanonicalName
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"$wrapperClass.valueOf($operandTerm)"
}
// NUMERIC TYPE -> Boolean
case (_: NumericTypeInfo[_], BOOLEAN_TYPE_INFO) =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"$operandTerm != 0"
}
// NUMERIC TYPE -> NUMERIC TYPE
case nti: NumericTypeInfo[_] if isNumeric(operand) =>
val targetTypeTerm = primitiveTypeTermForTypeInfo(nti)
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"($targetTypeTerm) $operandTerm"
}
// BigDecimal -> Boolean
case (BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO) =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"$operandTerm.compareTo(java.math.BigDecimal.ZERO) != 0"
}
// Boolean -> NUMERIC TYPE
case nti: NumericTypeInfo[_] if isBoolean(operand) =>
val targetTypeTerm = primitiveTypeTermForTypeInfo(nti)
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"($targetTypeTerm) ($operandTerm ? 1 : 0)"
}
// NUMERIC TYPE, BigDecimal -> NUMERIC TYPE, BigDecimal
case (_: NumericTypeInfo[_], _: NumericTypeInfo[_])
| (BIG_DEC_TYPE_INFO, _: NumericTypeInfo[_])
| (_: NumericTypeInfo[_], BIG_DEC_TYPE_INFO) =>
val operandCasting = numericCasting(operand.resultType, targetType)
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"${operandCasting(operandTerm)}"
}
case _ =>
throw new CodeGenException(s"Unsupported cast from '${operand.resultType}'" +
s"to '$targetType'.")
}
case (from, to) =>
throw new CodeGenException(s"Unsupported cast from '$from' to '$to'.")
}
def generateIfElse(
......@@ -519,4 +576,51 @@ object ScalarOperators {
GeneratedExpression(resultTerm, nullTerm, resultCode, resultType)
}
private def arithOpToDecMethod(operator: String): String = operator match {
case "+" => "add"
case "-" => "subtract"
case "*" => "multiply"
case "/" => "divide"
case "%" => "remainder"
case _ => throw new CodeGenException("Unsupported decimal arithmetic operator.")
}
private def numericCasting(
operandType: TypeInformation[_],
resultType: TypeInformation[_])
: (String) => String = {
def decToPrimMethod(targetType: TypeInformation[_]): String = targetType match {
case BYTE_TYPE_INFO => "byteValueExact"
case SHORT_TYPE_INFO => "shortValueExact"
case INT_TYPE_INFO => "intValueExact"
case LONG_TYPE_INFO => "longValueExact"
case FLOAT_TYPE_INFO => "floatValue"
case DOUBLE_TYPE_INFO => "doubleValue"
case _ => throw new CodeGenException("Unsupported decimal casting type.")
}
val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
// no casting necessary
if (operandType == resultType) {
(operandTerm) => s"$operandTerm"
}
// result type is decimal but numeric operand is not
else if (isDecimal(resultType) && !isDecimal(operandType) && isNumeric(operandType)) {
(operandTerm) =>
s"java.math.BigDecimal.valueOf((${superPrimitive(operandType)}) $operandTerm)"
}
// numeric result type is not decimal but operand is
else if (isNumeric(resultType) && !isDecimal(resultType) && isDecimal(operandType) ) {
(operandTerm) => s"$operandTerm.${decToPrimMethod(resultType)}()"
}
// result type and operand type are numeric but not decimal
else if (isNumeric(operandType) && isNumeric(resultType)
&& !isDecimal(operandType) && !isDecimal(resultType)) {
(operandTerm) => s"(($resultTypeTerm) $operandTerm)"
}
else {
throw new CodeGenException(s"Unsupported casting from $operandType to $resultType.")
}
}
}
......@@ -71,23 +71,24 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
"DOUBLE" ^^ { ti => BasicTypeInfo.DOUBLE_TYPE_INFO } |
("BOOL" | "BOOLEAN" ) ^^ { ti => BasicTypeInfo.BOOLEAN_TYPE_INFO } |
"STRING" ^^ { ti => BasicTypeInfo.STRING_TYPE_INFO } |
"DATE" ^^ { ti => BasicTypeInfo.DATE_TYPE_INFO }
"DATE" ^^ { ti => BasicTypeInfo.DATE_TYPE_INFO } |
"DECIMAL" ^^ { ti => BasicTypeInfo.BIG_DEC_TYPE_INFO }
// Literals
lazy val numberLiteral: PackratParser[Expression] =
((wholeNumber <~ ("L" | "l")) | floatingPointNumber | decimalNumber | wholeNumber) ^^ {
str =>
if (str.endsWith("L") || str.endsWith("l")) {
Literal(str.toLong)
} else if (str.matches("""-?\d+""")) {
Literal(str.toInt)
} else if (str.endsWith("f") | str.endsWith("F")) {
Literal(str.toFloat)
} else {
Literal(str.toDouble)
}
}
(wholeNumber <~ ("l" | "L")) ^^ { n => Literal(n.toLong) } |
(decimalNumber <~ ("p" | "P")) ^^ { n => Literal(BigDecimal(n)) } |
(floatingPointNumber | decimalNumber) ^^ {
n =>
if (n.matches("""-?\d+""")) {
Literal(n.toInt)
} else if (n.endsWith("f") || n.endsWith("F")) {
Literal(n.toFloat)
} else {
Literal(n.toDouble)
}
}
lazy val singleQuoteStringLiteral: Parser[Expression] =
("'" + """([^'\p{Cntrl}\\]|\\[\\'"bfnrt]|\\u[a-fA-F0-9]{4})*""" + "'").r ^^ {
......@@ -261,7 +262,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
lazy val unaryMinus: PackratParser[Expression] = "-" ~> composite ^^ { e => UnaryMinus(e) }
lazy val unary = unaryNot | unaryMinus | composite
lazy val unary = composite | unaryNot | unaryMinus
// arithmetic
......
......@@ -17,18 +17,17 @@
*/
package org.apache.flink.api.table.expressions
import scala.collection.JavaConversions._
import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, NumericTypeInfo, TypeInformation}
import org.apache.flink.api.table.typeutils.{TypeCheckUtils, TypeCoercion, TypeConverter}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isNumeric, isString}
import org.apache.flink.api.table.typeutils.{TypeCheckUtils, TypeCoercion}
import org.apache.flink.api.table.validate._
import scala.collection.JavaConversions._
abstract class BinaryArithmetic extends BinaryExpression {
def sqlOperator: SqlOperator
......@@ -45,9 +44,8 @@ abstract class BinaryArithmetic extends BinaryExpression {
// TODO: tighten this rule once we implemented type coercion rules during validation
override def validateInput(): ExprValidationResult = {
if (!left.resultType.isInstanceOf[NumericTypeInfo[_]] ||
!right.resultType.isInstanceOf[NumericTypeInfo[_]]) {
ValidationFailure(s"$this requires both operands Numeric, get" +
if (!isNumeric(left.resultType) || !isNumeric(right.resultType)) {
ValidationFailure(s"$this requires both operands Numeric, got " +
s"${left.resultType} and ${right.resultType}")
} else {
ValidationSuccess
......@@ -61,28 +59,24 @@ case class Plus(left: Expression, right: Expression) extends BinaryArithmetic {
val sqlOperator = SqlStdOperatorTable.PLUS
override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
val l = left.toRexNode
val r = right.toRexNode
if(SqlTypeName.STRING_TYPES.contains(l.getType.getSqlTypeName)) {
val cast: RexNode = relBuilder.cast(r,
TypeConverter.typeInfoToSqlType(BasicTypeInfo.STRING_TYPE_INFO))
relBuilder.call(SqlStdOperatorTable.PLUS, l, cast)
} else if(SqlTypeName.STRING_TYPES.contains(r.getType.getSqlTypeName)) {
val cast: RexNode = relBuilder.cast(l,
TypeConverter.typeInfoToSqlType(BasicTypeInfo.STRING_TYPE_INFO))
relBuilder.call(SqlStdOperatorTable.PLUS, cast, r)
if(isString(left.resultType)) {
val castedRight = Cast(right, BasicTypeInfo.STRING_TYPE_INFO)
relBuilder.call(SqlStdOperatorTable.PLUS, left.toRexNode, castedRight.toRexNode)
} else if(isString(right.resultType)) {
val castedLeft = Cast(left, BasicTypeInfo.STRING_TYPE_INFO)
relBuilder.call(SqlStdOperatorTable.PLUS, castedLeft.toRexNode, right.toRexNode)
} else {
relBuilder.call(SqlStdOperatorTable.PLUS, l, r)
val castedLeft = Cast(left, resultType)
val castedRight = Cast(right, resultType)
relBuilder.call(SqlStdOperatorTable.PLUS, castedLeft.toRexNode, castedRight.toRexNode)
}
}
// TODO: tighten this rule once we implemented type coercion rules during validation
override def validateInput(): ExprValidationResult = {
if (left.resultType == BasicTypeInfo.STRING_TYPE_INFO ||
right.resultType == BasicTypeInfo.STRING_TYPE_INFO) {
if (isString(left.resultType) || isString(right.resultType)) {
ValidationSuccess
} else if (!left.resultType.isInstanceOf[NumericTypeInfo[_]] ||
!right.resultType.isInstanceOf[NumericTypeInfo[_]]) {
} else if (!isNumeric(left.resultType) || !isNumeric(right.resultType)) {
ValidationFailure(s"$this requires Numeric or String input," +
s" get ${left.resultType} and ${right.resultType}")
} else {
......
......@@ -17,17 +17,16 @@
*/
package org.apache.flink.api.table.expressions
import scala.collection.JavaConversions._
import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.NumericTypeInfo
import org.apache.flink.api.table.typeutils.TypeCheckUtils.{isComparable, isNumeric}
import org.apache.flink.api.table.validate._
import scala.collection.JavaConversions._
abstract class BinaryComparison extends BinaryExpression {
def sqlOperator: SqlOperator
......@@ -39,11 +38,12 @@ abstract class BinaryComparison extends BinaryExpression {
// TODO: tighten this rule once we implemented type coercion rules during validation
override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match {
case (STRING_TYPE_INFO, STRING_TYPE_INFO) => ValidationSuccess
case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => ValidationSuccess
case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
case (lType, rType) if isComparable(lType) && lType == rType => ValidationSuccess
case (lType, rType) =>
ValidationFailure(
s"Comparison is only supported for Strings and numeric types, get $lType and $rType")
s"Comparison is only supported for numeric types and comparable types of same type," +
s"got $lType and $rType")
}
}
......@@ -53,13 +53,11 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
val sqlOperator: SqlOperator = SqlStdOperatorTable.EQUALS
override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match {
case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => ValidationSuccess
case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
// TODO widen this rule once we support custom objects as types (FLINK-3916)
case (lType, rType) if lType == rType => ValidationSuccess
case (lType, rType) =>
if (lType != rType) {
ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
} else {
ValidationSuccess
}
ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
}
}
......@@ -69,13 +67,11 @@ case class NotEqualTo(left: Expression, right: Expression) extends BinaryCompari
val sqlOperator: SqlOperator = SqlStdOperatorTable.NOT_EQUALS
override def validateInput(): ExprValidationResult = (left.resultType, right.resultType) match {
case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => ValidationSuccess
case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
// TODO widen this rule once we support custom objects as types (FLINK-3916)
case (lType, rType) if lType == rType => ValidationSuccess
case (lType, rType) =>
if (lType != rType) {
ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
} else {
ValidationSuccess
}
ValidationFailure(s"Inequality predicate on incompatible types: $lType and $rType")
}
}
......
......@@ -20,6 +20,7 @@ package org.apache.flink.api.table.expressions
import java.util.Date
import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.table.typeutils.TypeConverter
......@@ -35,6 +36,9 @@ object Literal {
case str: String => Literal(str, BasicTypeInfo.STRING_TYPE_INFO)
case bool: Boolean => Literal(bool, BasicTypeInfo.BOOLEAN_TYPE_INFO)
case date: Date => Literal(date, BasicTypeInfo.DATE_TYPE_INFO)
case javaDec: java.math.BigDecimal => Literal(javaDec, BasicTypeInfo.BIG_DEC_TYPE_INFO)
case scalaDec: scala.math.BigDecimal =>
Literal(scalaDec.bigDecimal, BasicTypeInfo.BIG_DEC_TYPE_INFO)
}
}
......@@ -42,7 +46,13 @@ case class Literal(value: Any, resultType: TypeInformation[_]) extends LeafExpre
override def toString = s"$value"
override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
relBuilder.literal(value)
resultType match {
case BasicTypeInfo.BIG_DEC_TYPE_INFO =>
val bigDecValue = value.asInstanceOf[java.math.BigDecimal]
val decType = relBuilder.getTypeFactory.createSqlType(SqlTypeName.DECIMAL)
relBuilder.getRexBuilder.makeExactLiteral(bigDecValue, decType)
case _ => relBuilder.literal(value)
}
}
}
......
......@@ -61,6 +61,7 @@ trait DataSetRel extends RelNode with FlinkRel {
case SqlTypeName.DOUBLE => s + 8
case SqlTypeName.VARCHAR => s + 12
case SqlTypeName.CHAR => s + 1
case SqlTypeName.DECIMAL => s + 12
case _ => throw new TableException("Unsupported data type encountered")
}
}
......
......@@ -155,6 +155,8 @@ object AggregateUtil {
new FloatSumAggregate
case DOUBLE =>
new DoubleSumAggregate
case DECIMAL =>
new DecimalSumAggregate
case sqlType: SqlTypeName =>
throw new TableException("Sum aggregate does no support type:" + sqlType)
}
......@@ -173,6 +175,8 @@ object AggregateUtil {
new FloatAvgAggregate
case DOUBLE =>
new DoubleAvgAggregate
case DECIMAL =>
new DecimalAvgAggregate
case sqlType: SqlTypeName =>
throw new TableException("Avg aggregate does no support type:" + sqlType)
}
......@@ -192,6 +196,8 @@ object AggregateUtil {
new FloatMinAggregate
case DOUBLE =>
new DoubleMinAggregate
case DECIMAL =>
new DecimalMinAggregate
case BOOLEAN =>
new BooleanMinAggregate
case sqlType: SqlTypeName =>
......@@ -211,6 +217,8 @@ object AggregateUtil {
new FloatMaxAggregate
case DOUBLE =>
new DoubleMaxAggregate
case DECIMAL =>
new DecimalMaxAggregate
case BOOLEAN =>
new BooleanMaxAggregate
case sqlType: SqlTypeName =>
......
......@@ -18,8 +18,9 @@
package org.apache.flink.api.table.runtime.aggregate
import com.google.common.math.LongMath
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row
import java.math.BigDecimal
import java.math.BigInteger
abstract class AvgAggregate[T] extends Aggregate[T] {
......@@ -251,3 +252,45 @@ class DoubleAvgAggregate extends FloatingAvgAggregate[Double] {
}
}
}
class DecimalAvgAggregate extends AvgAggregate[BigDecimal] {
override def intermediateDataType = Array(
BasicTypeInfo.BIG_DEC_TYPE_INFO,
BasicTypeInfo.LONG_TYPE_INFO)
override def initiate(partial: Row): Unit = {
partial.setField(partialSumIndex, BigDecimal.ZERO)
partial.setField(partialCountIndex, 0L)
}
override def prepare(value: Any, partial: Row): Unit = {
if (value == null) {
initiate(partial)
} else {
val input = value.asInstanceOf[BigDecimal]
partial.setField(partialSumIndex, input)
partial.setField(partialCountIndex, 1L)
}
}
override def merge(partial: Row, buffer: Row): Unit = {
val partialSum = partial.productElement(partialSumIndex).asInstanceOf[BigDecimal]
val partialCount = partial.productElement(partialCountIndex).asInstanceOf[Long]
val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[BigDecimal]
val bufferCount = buffer.productElement(partialCountIndex).asInstanceOf[Long]
buffer.setField(partialSumIndex, partialSum.add(bufferSum))
buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
}
override def evaluate(buffer: Row): BigDecimal = {
val bufferCount = buffer.productElement(partialCountIndex).asInstanceOf[Long]
if (bufferCount != 0) {
val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[BigDecimal]
bufferSum.divide(BigDecimal.valueOf(bufferCount))
} else {
null.asInstanceOf[BigDecimal]
}
}
}
......@@ -17,6 +17,8 @@
*/
package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row
......@@ -125,3 +127,45 @@ class BooleanMaxAggregate extends MaxAggregate[Boolean] {
override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)
}
class DecimalMaxAggregate extends Aggregate[BigDecimal] {
protected var minIndex: Int = _
override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
override def initiate(intermediate: Row): Unit = {
intermediate.setField(minIndex, null)
}
override def prepare(value: Any, partial: Row): Unit = {
if (value == null) {
initiate(partial)
} else {
partial.setField(minIndex, value)
}
}
override def merge(partial: Row, buffer: Row): Unit = {
val partialValue = partial.productElement(minIndex).asInstanceOf[BigDecimal]
if (partialValue != null) {
val bufferValue = buffer.productElement(minIndex).asInstanceOf[BigDecimal]
if (bufferValue != null) {
val min = if (partialValue.compareTo(bufferValue) > 0) partialValue else bufferValue
buffer.setField(minIndex, min)
} else {
buffer.setField(minIndex, partialValue)
}
}
}
override def evaluate(buffer: Row): BigDecimal = {
buffer.productElement(minIndex).asInstanceOf[BigDecimal]
}
override def supportPartial: Boolean = true
override def setAggOffsetInRow(aggOffset: Int): Unit = {
minIndex = aggOffset
}
}
......@@ -17,6 +17,7 @@
*/
package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row
......@@ -125,3 +126,45 @@ class BooleanMinAggregate extends MinAggregate[Boolean] {
override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)
}
class DecimalMinAggregate extends Aggregate[BigDecimal] {
protected var minIndex: Int = _
override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
override def initiate(intermediate: Row): Unit = {
intermediate.setField(minIndex, null)
}
override def prepare(value: Any, partial: Row): Unit = {
if (value == null) {
initiate(partial)
} else {
partial.setField(minIndex, value)
}
}
override def merge(partial: Row, buffer: Row): Unit = {
val partialValue = partial.productElement(minIndex).asInstanceOf[BigDecimal]
if (partialValue != null) {
val bufferValue = buffer.productElement(minIndex).asInstanceOf[BigDecimal]
if (bufferValue != null) {
val min = if (partialValue.compareTo(bufferValue) < 0) partialValue else bufferValue
buffer.setField(minIndex, min)
} else {
buffer.setField(minIndex, partialValue)
}
}
}
override def evaluate(buffer: Row): BigDecimal = {
buffer.productElement(minIndex).asInstanceOf[BigDecimal]
}
override def supportPartial: Boolean = true
override def setAggOffsetInRow(aggOffset: Int): Unit = {
minIndex = aggOffset
}
}
......@@ -17,6 +17,7 @@
*/
package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row
......@@ -85,3 +86,45 @@ class FloatSumAggregate extends SumAggregate[Float] {
class DoubleSumAggregate extends SumAggregate[Double] {
override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO)
}
class DecimalSumAggregate extends Aggregate[BigDecimal] {
protected var sumIndex: Int = _
override def intermediateDataType = Array(BasicTypeInfo.BIG_DEC_TYPE_INFO)
override def initiate(partial: Row): Unit = {
partial.setField(sumIndex, null)
}
override def merge(partial1: Row, buffer: Row): Unit = {
val partialValue = partial1.productElement(sumIndex).asInstanceOf[BigDecimal]
if (partialValue != null) {
val bufferValue = buffer.productElement(sumIndex).asInstanceOf[BigDecimal]
if (bufferValue != null) {
buffer.setField(sumIndex, partialValue.add(bufferValue))
} else {
buffer.setField(sumIndex, partialValue)
}
}
}
override def evaluate(buffer: Row): BigDecimal = {
buffer.productElement(sumIndex).asInstanceOf[BigDecimal]
}
override def prepare(value: Any, partial: Row): Unit = {
if (value == null) {
initiate(partial)
} else {
val input = value.asInstanceOf[BigDecimal]
partial.setField(sumIndex, input)
}
}
override def supportPartial: Boolean = true
override def setAggOffsetInRow(aggOffset: Int): Unit = {
sumIndex = aggOffset
}
}
......@@ -17,17 +17,37 @@
*/
package org.apache.flink.api.table.typeutils
import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO, STRING_TYPE_INFO}
import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation}
import org.apache.flink.api.table.validate._
object TypeCheckUtils {
def assertNumericExpr(dataType: TypeInformation[_], caller: String): ExprValidationResult = {
if (dataType.isInstanceOf[NumericTypeInfo[_]]) {
def isNumeric(dataType: TypeInformation[_]): Boolean = dataType match {
case _: NumericTypeInfo[_] => true
case BIG_DEC_TYPE_INFO => true
case _ => false
}
def isString(dataType: TypeInformation[_]): Boolean = dataType == STRING_TYPE_INFO
def isBoolean(dataType: TypeInformation[_]): Boolean = dataType == BOOLEAN_TYPE_INFO
def isDecimal(dataType: TypeInformation[_]): Boolean = dataType == BIG_DEC_TYPE_INFO
def isComparable(dataType: TypeInformation[_]): Boolean =
classOf[Comparable[_]].isAssignableFrom(dataType.getTypeClass)
def assertNumericExpr(
dataType: TypeInformation[_],
caller: String)
: ExprValidationResult = dataType match {
case _: NumericTypeInfo[_] =>
ValidationSuccess
} else {
case BIG_DEC_TYPE_INFO =>
ValidationSuccess
case _ =>
ValidationFailure(s"$caller requires numeric types, get $dataType here")
}
}
def assertOrderableExpr(dataType: TypeInformation[_], caller: String): ExprValidationResult = {
......
......@@ -37,11 +37,14 @@ object TypeCoercion {
def widerTypeOf(tp1: TypeInformation[_], tp2: TypeInformation[_]): Option[TypeInformation[_]] = {
(tp1, tp2) match {
case (tp1, tp2) if tp1 == tp2 => Some(tp1)
case (ti1, ti2) if ti1 == ti2 => Some(ti1)
case (_, STRING_TYPE_INFO) => Some(STRING_TYPE_INFO)
case (STRING_TYPE_INFO, _) => Some(STRING_TYPE_INFO)
case (_, BIG_DEC_TYPE_INFO) => Some(BIG_DEC_TYPE_INFO)
case (BIG_DEC_TYPE_INFO, _) => Some(BIG_DEC_TYPE_INFO)
case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) =>
val higherIndex = numericWideningPrecedence.lastIndexWhere(t => t == tp1 || t == tp2)
Some(numericWideningPrecedence(higherIndex))
......@@ -56,6 +59,8 @@ object TypeCoercion {
def canSafelyCast(from: TypeInformation[_], to: TypeInformation[_]): Boolean = (from, to) match {
case (_, STRING_TYPE_INFO) => true
case (_: NumericTypeInfo[_], BIG_DEC_TYPE_INFO) => true
case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) =>
if (numericWideningPrecedence.indexOf(from) < numericWideningPrecedence.indexOf(to)) {
true
......@@ -71,21 +76,24 @@ object TypeCoercion {
* Note: This may lose information during the cast.
*/
def canCast(from: TypeInformation[_], to: TypeInformation[_]): Boolean = (from, to) match {
case (from, to) if from == to => true
case (fromTp, toTp) if fromTp == toTp => true
case (_, STRING_TYPE_INFO) => true
case (_, DATE_TYPE_INFO) => false // Date type not supported yet.
case (_, VOID_TYPE_INFO) => false // Void type not supported
case (_, CHAR_TYPE_INFO) => false // Character type not supported.
case (STRING_TYPE_INFO, _: NumericTypeInfo[_]) => true
case (STRING_TYPE_INFO, BOOLEAN_TYPE_INFO) => true
case (STRING_TYPE_INFO, BIG_DEC_TYPE_INFO) => true
case (BOOLEAN_TYPE_INFO, _: NumericTypeInfo[_]) => true
case (BOOLEAN_TYPE_INFO, BIG_DEC_TYPE_INFO) => true
case (_: NumericTypeInfo[_], BOOLEAN_TYPE_INFO) => true
case (BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO) => true
case (_: NumericTypeInfo[_], _: NumericTypeInfo[_]) => true
case (BIG_DEC_TYPE_INFO, _: NumericTypeInfo[_]) => true
case (_: NumericTypeInfo[_], BIG_DEC_TYPE_INFO) => true
case _ => false
}
......
......@@ -56,6 +56,7 @@ object TypeConverter {
case STRING_TYPE_INFO => VARCHAR
case STRING_VALUE_TYPE_INFO => VARCHAR
case DATE_TYPE_INFO => DATE
case BIG_DEC_TYPE_INFO => DECIMAL
case CHAR_TYPE_INFO | CHAR_VALUE_TYPE_INFO =>
throw new TableException("Character type is not supported.")
......@@ -74,6 +75,7 @@ object TypeConverter {
case DOUBLE => DOUBLE_TYPE_INFO
case VARCHAR | CHAR => STRING_TYPE_INFO
case DATE => DATE_TYPE_INFO
case DECIMAL => BIG_DEC_TYPE_INFO
case NULL =>
throw new TableException("Type NULL is not supported. " +
......
......@@ -98,7 +98,9 @@ class AggregationsITCase(
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val sqlQuery = "SELECT avg(_1), avg(_2), avg(_3), avg(_4), avg(_5), avg(_6), count(_7)" +
val sqlQuery =
"SELECT avg(_1), avg(_2), avg(_3), avg(_4), avg(_5), avg(_6), count(_7), " +
" sum(CAST(_6 AS DECIMAL))" +
"FROM MyTable"
val ds = env.fromElements(
......@@ -108,7 +110,7 @@ class AggregationsITCase(
val result = tEnv.sql(sqlQuery)
val expected = "1,1,1,1,1.5,1.5,2"
val expected = "1,1,1,1,1.5,1.5,2,3.0"
val results = result.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
......
......@@ -157,6 +157,23 @@ class ExpressionsITCase(
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test
def testDecimalLiteral(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
val t = env
.fromElements(
(BigDecimal("78.454654654654654").bigDecimal, BigDecimal("4E+9999").bigDecimal)
)
.toTable(tEnv, 'a, 'b)
.select('a, 'b, BigDecimal("11.2"), BigDecimal("11.2").bigDecimal)
val expected = "78.454654654654654,4E+9999,11.2,11.2"
val results = t.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
// Date literals not yet supported
@Ignore
@Test
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.table.expressions
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.scala.table._
import org.apache.flink.api.table.Row
import org.apache.flink.api.table.expressions.utils.ExpressionTestBase
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.junit.Test
class DecimalTypeTest extends ExpressionTestBase {
@Test
def testDecimalLiterals(): Unit = {
// implicit double
testAllApis(
11.2,
"11.2",
"11.2",
"11.2")
// implicit double
testAllApis(
0.7623533651719233,
"0.7623533651719233",
"0.7623533651719233",
"0.7623533651719233")
// explicit decimal (with precision of 19)
testAllApis(
BigDecimal("1234567891234567891"),
"1234567891234567891p",
"1234567891234567891",
"1234567891234567891")
// explicit decimal (high precision, not SQL compliant)
testTableApi(
BigDecimal("123456789123456789123456789"),
"123456789123456789123456789p",
"123456789123456789123456789")
// explicit decimal (high precision, not SQL compliant)
testTableApi(
BigDecimal("12.3456789123456789123456789"),
"12.3456789123456789123456789p",
"12.3456789123456789123456789")
}
@Test
def testDecimalBorders(): Unit = {
testAllApis(
Double.MaxValue,
Double.MaxValue.toString,
Double.MaxValue.toString,
Double.MaxValue.toString)
testAllApis(
Double.MinValue,
Double.MinValue.toString,
Double.MinValue.toString,
Double.MinValue.toString)
testAllApis(
Double.MinValue.cast(FLOAT_TYPE_INFO),
s"${Double.MinValue}.cast(FLOAT)",
s"CAST(${Double.MinValue} AS FLOAT)",
Float.NegativeInfinity.toString)
testAllApis(
Byte.MinValue.cast(BYTE_TYPE_INFO),
s"(${Byte.MinValue}).cast(BYTE)",
s"CAST(${Byte.MinValue} AS TINYINT)",
Byte.MinValue.toString)
testAllApis(
Byte.MinValue.cast(BYTE_TYPE_INFO) - 1.cast(BYTE_TYPE_INFO),
s"(${Byte.MinValue}).cast(BYTE) - (1).cast(BYTE)",
s"CAST(${Byte.MinValue} AS TINYINT) - CAST(1 AS TINYINT)",
Byte.MaxValue.toString)
testAllApis(
Short.MinValue.cast(SHORT_TYPE_INFO),
s"(${Short.MinValue}).cast(SHORT)",
s"CAST(${Short.MinValue} AS SMALLINT)",
Short.MinValue.toString)
testAllApis(
Int.MinValue.cast(INT_TYPE_INFO) - 1,
s"(${Int.MinValue}).cast(INT) - 1",
s"CAST(${Int.MinValue} AS INT) - 1",
Int.MaxValue.toString)
testAllApis(
Long.MinValue.cast(LONG_TYPE_INFO),
s"(${Long.MinValue}L).cast(LONG)",
s"CAST(${Long.MinValue} AS BIGINT)",
Long.MinValue.toString)
}
@Test
def testDecimalCasting(): Unit = {
// from String
testTableApi(
"123456789123456789123456789".cast(BIG_DEC_TYPE_INFO),
"'123456789123456789123456789'.cast(DECIMAL)",
"123456789123456789123456789")
// from double
testAllApis(
'f3.cast(BIG_DEC_TYPE_INFO),
"f3.cast(DECIMAL)",
"CAST(f3 AS DECIMAL)",
"4.2")
// to double
testAllApis(
'f0.cast(DOUBLE_TYPE_INFO),
"f0.cast(DOUBLE)",
"CAST(f0 AS DOUBLE)",
"1.2345678912345679E8")
// to int
testAllApis(
'f4.cast(INT_TYPE_INFO),
"f4.cast(INT)",
"CAST(f4 AS INT)",
"123456789")
// to long
testAllApis(
'f4.cast(LONG_TYPE_INFO),
"f4.cast(LONG)",
"CAST(f4 AS BIGINT)",
"123456789")
// to boolean (not SQL compliant)
testTableApi(
'f1.cast(BOOLEAN_TYPE_INFO),
"f1.cast(BOOL)",
"true")
testTableApi(
'f5.cast(BOOLEAN_TYPE_INFO),
"f5.cast(BOOL)",
"false")
testTableApi(
BigDecimal("123456789.123456789123456789").cast(DOUBLE_TYPE_INFO),
"(123456789.123456789123456789p).cast(DOUBLE)",
"1.2345678912345679E8")
}
@Test
def testDecimalArithmetic(): Unit = {
// implicit cast to decimal
testAllApis(
'f1 + 12,
"f1 + 12",
"f1 + 12",
"123456789123456789123456801")
// implicit cast to decimal
testAllApis(
Literal(12) + 'f1,
"12 + f1",
"12 + f1",
"123456789123456789123456801")
// implicit cast to decimal
testAllApis(
'f1 + 12.3,
"f1 + 12.3",
"f1 + 12.3",
"123456789123456789123456801.3")
// implicit cast to decimal
testAllApis(
Literal(12.3) + 'f1,
"12.3 + f1",
"12.3 + f1",
"123456789123456789123456801.3")
testAllApis(
'f1 + 'f1,
"f1 + f1",
"f1 + f1",
"246913578246913578246913578")
testAllApis(
'f1 - 'f1,
"f1 - f1",
"f1 - f1",
"0")
testAllApis(
'f1 * 'f1,
"f1 * f1",
"f1 * f1",
"15241578780673678546105778281054720515622620750190521")
testAllApis(
'f1 / 'f1,
"f1 / f1",
"f1 / f1",
"1")
testAllApis(
'f1 % 'f1,
"f1 % f1",
"MOD(f1, f1)",
"0")
testAllApis(
-'f0,
"-f0",
"-f0",
"-123456789.123456789123456789")
}
@Test
def testDecimalComparison(): Unit = {
testAllApis(
'f1 < 12,
"f1 < 12",
"f1 < 12",
"false")
testAllApis(
'f1 > 12,
"f1 > 12",
"f1 > 12",
"true")
testAllApis(
'f1 === 12,
"f1 === 12",
"f1 = 12",
"false")
testAllApis(
'f5 === 0,
"f5 === 0",
"f5 = 0",
"true")
testAllApis(
'f1 === BigDecimal("123456789123456789123456789"),
"f1 === 123456789123456789123456789p",
"f1 = CAST('123456789123456789123456789' AS DECIMAL)",
"true")
testAllApis(
'f1 !== BigDecimal("123456789123456789123456789"),
"f1 !== 123456789123456789123456789p",
"f1 <> CAST('123456789123456789123456789' AS DECIMAL)",
"false")
testAllApis(
'f4 < 'f0,
"f4 < f0",
"f4 < f0",
"true")
// TODO add all tests if FLINK-4070 is fixed
testSqlApi(
"12 < f1",
"true")
}
// ----------------------------------------------------------------------------------------------
def testData = {
val testData = new Row(6)
testData.setField(0, BigDecimal("123456789.123456789123456789").bigDecimal)
testData.setField(1, BigDecimal("123456789123456789123456789").bigDecimal)
testData.setField(2, 42)
testData.setField(3, 4.2)
testData.setField(4, BigDecimal("123456789").bigDecimal)
testData.setField(5, BigDecimal("0.000").bigDecimal)
testData
}
def typeInfo = {
new RowTypeInfo(Seq(
BIG_DEC_TYPE_INFO,
BIG_DEC_TYPE_INFO,
INT_TYPE_INFO,
DOUBLE_TYPE_INFO,
BIG_DEC_TYPE_INFO,
BIG_DEC_TYPE_INFO)).asInstanceOf[TypeInformation[Any]]
}
}
......@@ -20,15 +20,13 @@ package org.apache.flink.api.table.expressions
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.table.expressions.utils.ExpressionEvaluator
import org.apache.flink.api.scala.table._
import org.apache.flink.api.table.Row
import org.apache.flink.api.table.expressions.{Expression, ExpressionParser}
import org.apache.flink.api.table.expressions.utils.ExpressionTestBase
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.junit.Assert.assertEquals
import org.junit.Test
class ScalarFunctionsTest {
class ScalarFunctionsTest extends ExpressionTestBase {
// ----------------------------------------------------------------------------------------------
// String functions
......@@ -36,19 +34,19 @@ class ScalarFunctionsTest {
@Test
def testSubstring(): Unit = {
testFunction(
testAllApis(
'f0.substring(2),
"f0.substring(2)",
"SUBSTRING(f0, 2)",
"his is a test String.")
testFunction(
testAllApis(
'f0.substring(2, 5),
"f0.substring(2, 5)",
"SUBSTRING(f0, 2, 5)",
"his i")
testFunction(
testAllApis(
'f0.substring(1, 'f7),
"f0.substring(1, f7)",
"SUBSTRING(f0, 1, f7)",
......@@ -57,25 +55,25 @@ class ScalarFunctionsTest {
@Test
def testTrim(): Unit = {
testFunction(
testAllApis(
'f8.trim(),
"f8.trim()",
"TRIM(f8)",
"This is a test String.")
testFunction(
testAllApis(
'f8.trim(removeLeading = true, removeTrailing = true, " "),
"trim(f8)",
"TRIM(f8)",
"This is a test String.")
testFunction(
testAllApis(
'f8.trim(removeLeading = false, removeTrailing = true, " "),
"f8.trim(TRAILING, ' ')",
"TRIM(TRAILING FROM f8)",
" This is a test String.")
testFunction(
testAllApis(
'f0.trim(removeLeading = true, removeTrailing = true, "."),
"trim(BOTH, '.', f0)",
"TRIM(BOTH '.' FROM f0)",
......@@ -84,13 +82,13 @@ class ScalarFunctionsTest {
@Test
def testCharLength(): Unit = {
testFunction(
testAllApis(
'f0.charLength(),
"f0.charLength()",
"CHAR_LENGTH(f0)",
"22")
testFunction(
testAllApis(
'f0.charLength(),
"charLength(f0)",
"CHARACTER_LENGTH(f0)",
......@@ -99,7 +97,7 @@ class ScalarFunctionsTest {
@Test
def testUpperCase(): Unit = {
testFunction(
testAllApis(
'f0.upperCase(),
"f0.upperCase()",
"UPPER(f0)",
......@@ -108,7 +106,7 @@ class ScalarFunctionsTest {
@Test
def testLowerCase(): Unit = {
testFunction(
testAllApis(
'f0.lowerCase(),
"f0.lowerCase()",
"LOWER(f0)",
......@@ -117,7 +115,7 @@ class ScalarFunctionsTest {
@Test
def testInitCap(): Unit = {
testFunction(
testAllApis(
'f0.initCap(),
"f0.initCap()",
"INITCAP(f0)",
......@@ -126,7 +124,7 @@ class ScalarFunctionsTest {
@Test
def testConcat(): Unit = {
testFunction(
testAllApis(
'f0 + 'f0,
"f0 + f0",
"f0||f0",
......@@ -135,13 +133,13 @@ class ScalarFunctionsTest {
@Test
def testLike(): Unit = {
testFunction(
testAllApis(
'f0.like("Th_s%"),
"f0.like('Th_s%')",
"f0 LIKE 'Th_s%'",
"true")
testFunction(
testAllApis(
'f0.like("%is a%"),
"f0.like('%is a%')",
"f0 LIKE '%is a%'",
......@@ -150,13 +148,13 @@ class ScalarFunctionsTest {
@Test
def testNotLike(): Unit = {
testFunction(
testAllApis(
!'f0.like("Th_s%"),
"!f0.like('Th_s%')",
"f0 NOT LIKE 'Th_s%'",
"false")
testFunction(
testAllApis(
!'f0.like("%is a%"),
"!f0.like('%is a%')",
"f0 NOT LIKE '%is a%'",
......@@ -165,13 +163,13 @@ class ScalarFunctionsTest {
@Test
def testSimilar(): Unit = {
testFunction(
testAllApis(
'f0.similar("_*"),
"f0.similar('_*')",
"f0 SIMILAR TO '_*'",
"true")
testFunction(
testAllApis(
'f0.similar("This (is)? a (test)+ Strin_*"),
"f0.similar('This (is)? a (test)+ Strin_*')",
"f0 SIMILAR TO 'This (is)? a (test)+ Strin_*'",
......@@ -180,13 +178,13 @@ class ScalarFunctionsTest {
@Test
def testNotSimilar(): Unit = {
testFunction(
testAllApis(
!'f0.similar("_*"),
"!f0.similar('_*')",
"f0 NOT SIMILAR TO '_*'",
"false")
testFunction(
testAllApis(
!'f0.similar("This (is)? a (test)+ Strin_*"),
"!f0.similar('This (is)? a (test)+ Strin_*')",
"f0 NOT SIMILAR TO 'This (is)? a (test)+ Strin_*'",
......@@ -195,19 +193,19 @@ class ScalarFunctionsTest {
@Test
def testMod(): Unit = {
testFunction(
testAllApis(
'f4.mod('f7),
"f4.mod(f7)",
"MOD(f4, f7)",
"2")
testFunction(
testAllApis(
'f4.mod(3),
"mod(f4, 3)",
"MOD(f4, 3)",
"2")
testFunction(
testAllApis(
'f4 % 3,
"mod(44, 3)",
"MOD(44, 3)",
......@@ -217,37 +215,43 @@ class ScalarFunctionsTest {
@Test
def testExp(): Unit = {
testFunction(
testAllApis(
'f2.exp(),
"f2.exp()",
"EXP(f2)",
math.exp(42.toByte).toString)
testFunction(
testAllApis(
'f3.exp(),
"f3.exp()",
"EXP(f3)",
math.exp(43.toShort).toString)
testFunction(
testAllApis(
'f4.exp(),
"f4.exp()",
"EXP(f4)",
math.exp(44.toLong).toString)
testFunction(
testAllApis(
'f5.exp(),
"f5.exp()",
"EXP(f5)",
math.exp(4.5.toFloat).toString)
testFunction(
testAllApis(
'f6.exp(),
"f6.exp()",
"EXP(f6)",
math.exp(4.6).toString)
testFunction(
testAllApis(
'f7.exp(),
"exp(3)",
"EXP(3)",
math.exp(3).toString)
testAllApis(
'f7.exp(),
"exp(3)",
"EXP(3)",
......@@ -256,31 +260,31 @@ class ScalarFunctionsTest {
@Test
def testLog10(): Unit = {
testFunction(
testAllApis(
'f2.log10(),
"f2.log10()",
"LOG10(f2)",
math.log10(42.toByte).toString)
testFunction(
testAllApis(
'f3.log10(),
"f3.log10()",
"LOG10(f3)",
math.log10(43.toShort).toString)
testFunction(
testAllApis(
'f4.log10(),
"f4.log10()",
"LOG10(f4)",
math.log10(44.toLong).toString)
testFunction(
testAllApis(
'f5.log10(),
"f5.log10()",
"LOG10(f5)",
math.log10(4.5.toFloat).toString)
testFunction(
testAllApis(
'f6.log10(),
"f6.log10()",
"LOG10(f6)",
......@@ -289,19 +293,25 @@ class ScalarFunctionsTest {
@Test
def testPower(): Unit = {
testFunction(
testAllApis(
'f2.power('f7),
"f2.power(f7)",
"POWER(f2, f7)",
math.pow(42.toByte, 3).toString)
testFunction(
testAllApis(
'f3.power('f6),
"f3.power(f6)",
"POWER(f3, f6)",
math.pow(43.toShort, 4.6D).toString)
testFunction(
testAllApis(
'f4.power('f5),
"f4.power(f5)",
"POWER(f4, f5)",
math.pow(44.toLong, 4.5.toFloat).toString)
testAllApis(
'f4.power('f5),
"f4.power(f5)",
"POWER(f4, f5)",
......@@ -310,31 +320,31 @@ class ScalarFunctionsTest {
@Test
def testLn(): Unit = {
testFunction(
testAllApis(
'f2.ln(),
"f2.ln()",
"LN(f2)",
math.log(42.toByte).toString)
testFunction(
testAllApis(
'f3.ln(),
"f3.ln()",
"LN(f3)",
math.log(43.toShort).toString)
testFunction(
testAllApis(
'f4.ln(),
"f4.ln()",
"LN(f4)",
math.log(44.toLong).toString)
testFunction(
testAllApis(
'f5.ln(),
"f5.ln()",
"LN(f5)",
math.log(4.5.toFloat).toString)
testFunction(
testAllApis(
'f6.ln(),
"f6.ln()",
"LN(f6)",
......@@ -343,102 +353,116 @@ class ScalarFunctionsTest {
@Test
def testAbs(): Unit = {
testFunction(
testAllApis(
'f2.abs(),
"f2.abs()",
"ABS(f2)",
"42")
testFunction(
testAllApis(
'f3.abs(),
"f3.abs()",
"ABS(f3)",
"43")
testFunction(
testAllApis(
'f4.abs(),
"f4.abs()",
"ABS(f4)",
"44")
testFunction(
testAllApis(
'f5.abs(),
"f5.abs()",
"ABS(f5)",
"4.5")
testFunction(
testAllApis(
'f6.abs(),
"f6.abs()",
"ABS(f6)",
"4.6")
testFunction(
testAllApis(
'f9.abs(),
"f9.abs()",
"ABS(f9)",
"42")
testFunction(
testAllApis(
'f10.abs(),
"f10.abs()",
"ABS(f10)",
"43")
testFunction(
testAllApis(
'f11.abs(),
"f11.abs()",
"ABS(f11)",
"44")
testFunction(
testAllApis(
'f12.abs(),
"f12.abs()",
"ABS(f12)",
"4.5")
testFunction(
testAllApis(
'f13.abs(),
"f13.abs()",
"ABS(f13)",
"4.6")
testAllApis(
'f15.abs(),
"f15.abs()",
"ABS(f15)",
"1231.1231231321321321111")
}
@Test
def testArithmeticFloorCeil(): Unit = {
testFunction(
testAllApis(
'f5.floor(),
"f5.floor()",
"FLOOR(f5)",
"4.0")
testFunction(
testAllApis(
'f5.ceil(),
"f5.ceil()",
"CEIL(f5)",
"5.0")
testFunction(
testAllApis(
'f3.floor(),
"f3.floor()",
"FLOOR(f3)",
"43")
testFunction(
testAllApis(
'f3.ceil(),
"f3.ceil()",
"CEIL(f3)",
"43")
testAllApis(
'f15.floor(),
"f15.floor()",
"FLOOR(f15)",
"-1232")
testAllApis(
'f15.ceil(),
"f15.ceil()",
"CEIL(f15)",
"-1231")
}
// ----------------------------------------------------------------------------------------------
def testFunction(
expr: Expression,
exprString: String,
sqlExpr: String,
expected: String): Unit = {
val testData = new Row(15)
def testData = {
val testData = new Row(16)
testData.setField(0, "This is a test String.")
testData.setField(1, true)
testData.setField(2, 42.toByte)
......@@ -454,8 +478,12 @@ class ScalarFunctionsTest {
testData.setField(12, -4.5.toFloat)
testData.setField(13, -4.6)
testData.setField(14, -3)
testData.setField(15, BigDecimal("-1231.1231231321321321111").bigDecimal)
testData
}
val typeInfo = new RowTypeInfo(Seq(
def typeInfo = {
new RowTypeInfo(Seq(
STRING_TYPE_INFO,
BOOLEAN_TYPE_INFO,
BYTE_TYPE_INFO,
......@@ -470,21 +498,7 @@ class ScalarFunctionsTest {
LONG_TYPE_INFO,
FLOAT_TYPE_INFO,
DOUBLE_TYPE_INFO,
INT_TYPE_INFO)).asInstanceOf[TypeInformation[Any]]
val exprResult = ExpressionEvaluator.evaluate(testData, typeInfo, expr)
assertEquals(expected, exprResult)
val exprStringResult = ExpressionEvaluator.evaluate(
testData,
typeInfo,
ExpressionParser.parseExpression(exprString))
assertEquals(expected, exprStringResult)
val exprSqlResult = ExpressionEvaluator.evaluate(testData, typeInfo, sqlExpr)
assertEquals(expected, exprSqlResult)
INT_TYPE_INFO,
BIG_DEC_TYPE_INFO)).asInstanceOf[TypeInformation[Any]]
}
}
......@@ -20,7 +20,7 @@ package org.apache.flink.api.table.expressions.utils
import org.apache.calcite.rel.logical.LogicalProject
import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.`type`.SqlTypeName.VARCHAR
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.tools.{Frameworks, RelBuilder}
import org.apache.flink.api.common.functions.{Function, MapFunction}
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
......@@ -28,25 +28,29 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.{DataSet => JDataSet}
import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedFunction}
import org.apache.flink.api.table.expressions.Expression
import org.apache.flink.api.table.expressions.{Expression, ExpressionParser}
import org.apache.flink.api.table.runtime.FunctionCompiler
import org.apache.flink.api.table.{BatchTableEnvironment, TableConfig, TableEnvironment}
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.api.table.{BatchTableEnvironment, Row, TableConfig, TableEnvironment}
import org.junit.Assert._
import org.junit.{After, Before}
import org.mockito.Mockito._
import scala.collection.mutable
/**
* Utility to translate and evaluate an RexNode or Table API expression to a String.
* Base test class for expression tests.
*/
object ExpressionEvaluator {
abstract class ExpressionTestBase {
// TestCompiler that uses current class loader
class TestCompiler[T <: Function] extends FunctionCompiler[T] {
def compile(genFunc: GeneratedFunction[T]): Class[T] =
compile(getClass.getClassLoader, genFunc.name, genFunc.code)
}
private val testExprs = mutable.LinkedHashSet[(RexNode, String)]()
private def prepareTable(
typeInfo: TypeInformation[Any]): (String, RelBuilder, TableEnvironment) = {
// setup test utils
private val tableName = "testTable"
private val context = prepareContext(typeInfo)
private val planner = Frameworks.getPlanner(context._2.getFrameworkConfig)
private def prepareContext(typeInfo: TypeInformation[Any]): (RelBuilder, TableEnvironment) = {
// create DataSetTable
val dataSetMock = mock(classOf[DataSet[Any]])
val jDataSetMock = mock(classOf[JDataSet[Any]])
......@@ -55,65 +59,128 @@ object ExpressionEvaluator {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
val tableName = "myTable"
tEnv.registerDataSet(tableName, dataSetMock)
// prepare RelBuilder
val relBuilder = tEnv.getRelBuilder
relBuilder.scan(tableName)
(tableName, relBuilder, tEnv)
(relBuilder, tEnv)
}
def evaluate(data: Any, typeInfo: TypeInformation[Any], sqlExpr: String): String = {
// create DataSetTable
val table = prepareTable(typeInfo)
// create RelNode from SQL expression
val planner = Frameworks.getPlanner(table._3.getFrameworkConfig)
val parsed = planner.parse("SELECT " + sqlExpr + " FROM " + table._1)
val validated = planner.validate(parsed)
val converted = planner.rel(validated)
val expr: RexNode = converted.rel.asInstanceOf[LogicalProject].getChildExps.get(0)
def testData: Any
evaluate(data, typeInfo, table._2, expr)
}
def typeInfo: TypeInformation[Any]
def evaluate(data: Any, typeInfo: TypeInformation[Any], expr: Expression): String = {
val table = prepareTable(typeInfo)
val env = table._3
val resolvedExpr =
env.asInstanceOf[BatchTableEnvironment].scan("myTable").select(expr).
getRelNode.asInstanceOf[LogicalProject].getChildExps.get(0)
evaluate(data, typeInfo, table._2, resolvedExpr)
@Before
def resetTestExprs() = {
testExprs.clear()
}
def evaluate(
data: Any,
typeInfo: TypeInformation[Any],
relBuilder: RelBuilder,
rexNode: RexNode): String = {
// generate code for Mapper
@After
def evaluateExprs() = {
val relBuilder = context._1
val config = new TableConfig()
val generator = new CodeGenerator(config, false, typeInfo)
val genExpr = generator.generateExpression(relBuilder.cast(rexNode, VARCHAR)) // cast to String
// cast expressions to String
val stringTestExprs = testExprs.map(expr => relBuilder.cast(expr._1, VARCHAR)).toSeq
// generate code
val resultType = new RowTypeInfo(Seq.fill(testExprs.size)(STRING_TYPE_INFO))
val genExpr = generator.generateResultExpression(
resultType,
resultType.getFieldNames,
stringTestExprs)
val bodyCode =
s"""
|${genExpr.code}
|return ${genExpr.resultTerm};
|""".stripMargin
val genFunc = generator.generateFunction[MapFunction[Any, String]](
"TestFunction",
classOf[MapFunction[Any, String]],
bodyCode,
STRING_TYPE_INFO.asInstanceOf[TypeInformation[Any]])
resultType.asInstanceOf[TypeInformation[Any]])
// compile and evaluate
val clazz = new TestCompiler[MapFunction[Any, String]]().compile(genFunc)
val mapper = clazz.newInstance()
mapper.map(data)
val result = mapper.map(testData).asInstanceOf[Row]
// compare
testExprs
.zipWithIndex
.foreach {
case ((expr, expected), index) =>
assertEquals(s"Wrong result for: $expr", expected, result.productElement(index))
}
}
private def addSqlTestExpr(sqlExpr: String, expected: String): Unit = {
// create RelNode from SQL expression
val parsed = planner.parse(s"SELECT $sqlExpr FROM $tableName")
val validated = planner.validate(parsed)
val converted = planner.rel(validated)
// extract RexNode
val expr: RexNode = converted.rel.asInstanceOf[LogicalProject].getChildExps.get(0)
testExprs.add((expr, expected))
planner.close()
}
private def addTableApiTestExpr(tableApiExpr: Expression, expected: String): Unit = {
val env = context._2
val expr = env
.asInstanceOf[BatchTableEnvironment]
.scan(tableName)
.select(tableApiExpr)
.getRelNode
.asInstanceOf[LogicalProject]
.getChildExps
.get(0)
testExprs.add((expr, expected))
}
private def addTableApiTestExpr(tableApiString: String, expected: String): Unit = {
addTableApiTestExpr(ExpressionParser.parseExpression(tableApiString), expected)
}
def testAllApis(
expr: Expression,
exprString: String,
sqlExpr: String,
expected: String)
: Unit = {
addTableApiTestExpr(expr, expected)
addTableApiTestExpr(exprString, expected)
addSqlTestExpr(sqlExpr, expected)
}
def testTableApi(
expr: Expression,
exprString: String,
expected: String)
: Unit = {
addTableApiTestExpr(expr, expected)
addTableApiTestExpr(exprString, expected)
}
def testSqlApi(
sqlExpr: String,
expected: String)
: Unit = {
addSqlTestExpr(sqlExpr, expected)
}
// ----------------------------------------------------------------------------------------------
// TestCompiler that uses current class loader
class TestCompiler[T <: Function] extends FunctionCompiler[T] {
def compile(genFunc: GeneratedFunction[T]): Class[T] =
compile(getClass.getClassLoader, genFunc.name, genFunc.code)
}
}
......@@ -18,6 +18,7 @@
package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
import org.apache.flink.api.table.Row
import org.junit.Test
import org.junit.Assert.assertEquals
......@@ -63,8 +64,13 @@ abstract class AggregateTestBase[T] {
finalAgg(rows)
}
assertEquals(expected, result)
(expected, result) match {
case (e: BigDecimal, r: BigDecimal) =>
// BigDecimal.equals() value and scale but we are only interested in value.
assert(e.compareTo(r) == 0)
case _ =>
assertEquals(expected, result)
}
}
}
......
......@@ -18,6 +18,8 @@
package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
abstract class AvgAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
private val numeric: Numeric[T] = implicitly[Numeric[T]]
......@@ -122,3 +124,31 @@ class DoubleAvgAggregateTest extends AvgAggregateTestBase[Double] {
override def aggregator = new DoubleAvgAggregate()
}
class DecimalAvgAggregateTest extends AggregateTestBase[BigDecimal] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(
new BigDecimal("987654321000000"),
new BigDecimal("-0.000000000012345"),
null,
new BigDecimal("0.000000000012345"),
new BigDecimal("-987654321000000"),
null,
new BigDecimal("0")
),
Seq(
null,
null,
null,
null
)
)
override def expectedResults: Seq[BigDecimal] = Seq(
BigDecimal.ZERO,
null
)
override def aggregator: Aggregate[BigDecimal] = new DecimalAvgAggregate()
}
......@@ -18,6 +18,8 @@
package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
abstract class MaxAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
private val numeric: Numeric[T] = implicitly[Numeric[T]]
......@@ -141,3 +143,35 @@ class BooleanMaxAggregateTest extends AggregateTestBase[Boolean] {
override def aggregator: Aggregate[Boolean] = new BooleanMaxAggregate()
}
class DecimalMaxAggregateTest extends AggregateTestBase[BigDecimal] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(
new BigDecimal("1"),
new BigDecimal("1000.000001"),
new BigDecimal("-1"),
new BigDecimal("-999.998999"),
null,
new BigDecimal("0"),
new BigDecimal("-999.999"),
null,
new BigDecimal("999.999")
),
Seq(
null,
null,
null,
null,
null
)
)
override def expectedResults: Seq[BigDecimal] = Seq(
new BigDecimal("1000.000001"),
null
)
override def aggregator: Aggregate[BigDecimal] = new DecimalMaxAggregate()
}
......@@ -18,6 +18,8 @@
package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
abstract class MinAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
private val numeric: Numeric[T] = implicitly[Numeric[T]]
......@@ -141,3 +143,35 @@ class BooleanMinAggregateTest extends AggregateTestBase[Boolean] {
override def aggregator: Aggregate[Boolean] = new BooleanMinAggregate()
}
class DecimalMinAggregateTest extends AggregateTestBase[BigDecimal] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(
new BigDecimal("1"),
new BigDecimal("1000"),
new BigDecimal("-1"),
new BigDecimal("-999.998999"),
null,
new BigDecimal("0"),
new BigDecimal("-999.999"),
null,
new BigDecimal("999.999")
),
Seq(
null,
null,
null,
null,
null
)
)
override def expectedResults: Seq[BigDecimal] = Seq(
new BigDecimal("-999.999"),
null
)
override def aggregator: Aggregate[BigDecimal] = new DecimalMinAggregate()
}
......@@ -18,6 +18,8 @@
package org.apache.flink.api.table.runtime.aggregate
import java.math.BigDecimal
abstract class SumAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
private val numeric: Numeric[T] = implicitly[Numeric[T]]
......@@ -97,3 +99,39 @@ class DoubleSumAggregateTest extends SumAggregateTestBase[Double] {
override def aggregator: Aggregate[Double] = new DoubleSumAggregate
}
class DecimalSumAggregateTest extends AggregateTestBase[BigDecimal] {
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(
new BigDecimal("1"),
new BigDecimal("2"),
new BigDecimal("3"),
null,
new BigDecimal("0"),
new BigDecimal("-1000"),
new BigDecimal("0.000000000002"),
new BigDecimal("1000"),
new BigDecimal("-0.000000000001"),
new BigDecimal("999.999"),
null,
new BigDecimal("4"),
new BigDecimal("-999.999"),
null
),
Seq(
null,
null,
null,
null,
null
)
)
override def expectedResults: Seq[BigDecimal] = Seq(
new BigDecimal("10.000000000001"),
null
)
override def aggregator: Aggregate[BigDecimal] = new DecimalSumAggregate()
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册