提交 5e6e8515 编写于 作者: T Timo Walther

[FLINK-15487][table] Update code generation for new type inference

This updates the code generation for the new type inference and thus
completes FLINK-15487. Scalar functions now work with the types supported
by the planner. The tests added in this PR only cover basic behavior; more
tests per data type are still needed, but that is a follow-up issue.

This closes #10960.
上级 75248b48
......@@ -26,7 +26,6 @@ import org.apache.flink.table.dataformat.{BinaryStringUtil, Decimal, _}
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.runtime.dataview.StateDataViewStore
import org.apache.flink.table.runtime.generated.{AggsHandleFunction, HashFunction, NamespaceAggsHandleFunction, TableAggsHandleFunction}
import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter
import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter.getInternalClassForType
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable
......@@ -36,11 +35,12 @@ import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.table.types.logical._
import org.apache.flink.types.Row
import java.lang.reflect.Method
import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Short => JShort}
import java.util.concurrent.atomic.AtomicInteger
import org.apache.flink.table.planner.codegen.GenerateUtils.{generateInputFieldUnboxing, generateNonNullField}
object CodeGenUtils {
// ------------------------------- DEFAULT TERMS ------------------------------------------
......@@ -116,7 +116,28 @@ object CodeGenUtils {
/**
* Retrieve the canonical name of a class type.
*/
def className[T](implicit m: Manifest[T]): String = m.runtimeClass.getCanonicalName
def className[T](implicit m: Manifest[T]): String = {
  val runtimeClass = m.runtimeClass
  // a missing canonical name is reported instead of returning null to the caller
  Option(runtimeClass.getCanonicalName).getOrElse {
    throw new CodeGenException(
      s"Class '${runtimeClass.getName}' does not have a canonical name. " +
        s"Make sure it is statically accessible.")
  }
}
/**
 * Produces the term by which the given class is referenced in generated Java code.
 *
 * The term is the canonical name of the class. A [[CodeGenException]] is thrown if the class
 * does not have a canonical name.
 */
def typeTerm(clazz: Class[_]): String = clazz.getCanonicalName match {
  case null =>
    throw new CodeGenException(
      s"Class '${clazz.getName}' does not have a canonical name. " +
        s"Make sure it is statically accessible.")
  case canonicalName =>
    canonicalName
}
// when casting we first need to unbox Primitives, for example,
// float a = 1.0f;
......@@ -167,18 +188,6 @@ object CodeGenUtils {
case RAW => className[BinaryGeneric[_]]
}
/**
 * Returns the canonical class name of the external representation of the given data type.
 *
 * The conversion class of the data type is used if it is set. Otherwise the default external
 * class of the logical type is looked up via [[ClassLogicalTypeConverter]].
 */
def boxedTypeTermForExternalType(t: DataType): String = {
  val externalClass: Class[_] = Option[Class[_]](t.getConversionClass).getOrElse(
    ClassLogicalTypeConverter.getDefaultExternalClassForType(t.getLogicalType))
  externalClass.getCanonicalName
}
/**
* Gets the default value for a primitive type, and null for generic types
*/
......@@ -682,11 +691,11 @@ object CodeGenUtils {
genToInternal(ctx, t)(term)
def genToInternal(ctx: CodeGeneratorContext, t: DataType): String => String = {
val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(t))
if (isConverterIdentity(t)) {
term => s"($iTerm) $term"
term => s"$term"
} else {
val eTerm = boxedTypeTermForExternalType(t)
val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(t))
val eTerm = typeTerm(t.getConversionClass)
val converter = ctx.addReusableObject(
DataFormatConverters.getConverterForDataType(t),
"converter")
......@@ -694,38 +703,70 @@ object CodeGenUtils {
}
}
/**
* Generates code for converting the given external source data type to the internal data format.
*
* Use this function for converting at the edges of the API.
*/
def genToInternalIfNeeded(
ctx: CodeGeneratorContext,
t: DataType,
term: String): String = {
if (isInternalClass(t)) {
s"(${boxedTypeTermForType(fromDataTypeToLogicalType(t))}) $term"
sourceDataType: DataType,
externalTerm: String)
: GeneratedExpression = {
val sourceType = sourceDataType.getLogicalType
val sourceClass = sourceDataType.getConversionClass
// convert external source type to internal format
val internalResultTerm = if (isInternalClass(sourceDataType)) {
s"$externalTerm"
} else {
genToInternal(ctx, t, term)
genToInternal(ctx, sourceDataType, externalTerm)
}
// extract null term from result term
if (sourceClass.isPrimitive) {
generateNonNullField(sourceType, internalResultTerm)
} else {
generateInputFieldUnboxing(ctx, sourceType, externalTerm, internalResultTerm)
}
}
def genToExternal(ctx: CodeGeneratorContext, t: DataType, term: String): String = {
val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(t))
if (isConverterIdentity(t)) {
s"($iTerm) $term"
def genToExternal(
ctx: CodeGeneratorContext,
targetType: DataType,
internalTerm: String): String = {
if (isConverterIdentity(targetType)) {
s"$internalTerm"
} else {
val eTerm = boxedTypeTermForExternalType(t)
val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(targetType))
val eTerm = typeTerm(targetType.getConversionClass)
val converter = ctx.addReusableObject(
DataFormatConverters.getConverterForDataType(t),
DataFormatConverters.getConverterForDataType(targetType),
"converter")
s"($eTerm) $converter.toExternal(($iTerm) $term)"
s"($eTerm) $converter.toExternal(($iTerm) $internalTerm)"
}
}
/**
* Generates code for converting the internal data format to the given external target data type.
*
* Use this function for converting at the edges of the API.
*/
def genToExternalIfNeeded(
ctx: CodeGeneratorContext,
t: DataType,
term: String): String = {
if (isInternalClass(t)) {
s"(${boxedTypeTermForType(fromDataTypeToLogicalType(t))}) $term"
targetDataType: DataType,
internalExpr: GeneratedExpression)
: String = {
val targetType = fromDataTypeToLogicalType(targetDataType)
// convert internal format to target type
val externalResultTerm = if (isInternalClass(targetDataType)) {
s"(${boxedTypeTermForType(targetType)}) ${internalExpr.resultTerm}"
} else {
genToExternal(ctx, targetDataType, internalExpr.resultTerm)
}
// merge null term into the result term
if (targetDataType.getConversionClass.isPrimitive) {
externalResultTerm
} else {
genToExternal(ctx, t, term)
s"${internalExpr.nullTerm} ? null : ($externalResultTerm)"
}
}
......
......@@ -27,7 +27,7 @@ import org.apache.flink.table.planner.codegen.CodeGenUtils.{requireTemporal, req
import org.apache.flink.table.planner.codegen.GenerateUtils._
import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens._
import org.apache.flink.table.planner.codegen.calls.{FunctionGenerator, ScalarFunctionCallGen, StringCallGen, TableFunctionCallGen}
import org.apache.flink.table.planner.codegen.calls.{BridgingSqlFunctionCallGen, FunctionGenerator, ScalarFunctionCallGen, StringCallGen, TableFunctionCallGen}
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._
import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction
import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction}
......@@ -41,6 +41,8 @@ import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.`type`.{ReturnTypes, SqlTypeName}
import org.apache.calcite.util.TimestampString
import org.apache.flink.table.functions.{ScalarFunction, UserDefinedFunction}
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction
import scala.collection.JavaConversions._
......@@ -482,7 +484,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
case (o@_, _) => o.accept(this)
}
generateCallExpression(ctx, call.getOperator, operands, resultType)
generateCallExpression(ctx, call, operands, resultType)
}
override def visitOver(over: RexOver): GeneratedExpression =
......@@ -498,10 +500,10 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
private def generateCallExpression(
ctx: CodeGeneratorContext,
operator: SqlOperator,
call: RexCall,
operands: Seq[GeneratedExpression],
resultType: LogicalType): GeneratedExpression = {
operator match {
call.getOperator match {
// arithmetic
case PLUS if isNumeric(resultType) =>
val left = operands.head
......@@ -780,21 +782,25 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
tsf.makeFunction(getOperandLiterals(operands), operands.map(_.resultType).toArray))
.generate(ctx, operands, resultType)
case bf: BridgingSqlFunction if bf.getDefinition.isInstanceOf[ScalarFunction] =>
new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType)
// advanced scalar functions
case sqlOperator: SqlOperator =>
StringCallGen.generateCallExpression(ctx, operator, operands, resultType).getOrElse {
FunctionGenerator
.getCallGenerator(
sqlOperator,
operands.map(expr => expr.resultType),
resultType)
.getOrElse(
throw new CodeGenException(s"Unsupported call: " +
s"$sqlOperator(${operands.map(_.resultType).mkString(", ")}) \n" +
s"If you think this function should be supported, " +
s"you can create an issue and start a discussion for it."))
.generate(ctx, operands, resultType)
}
StringCallGen.generateCallExpression(ctx, call.getOperator, operands, resultType)
.getOrElse {
FunctionGenerator
.getCallGenerator(
sqlOperator,
operands.map(expr => expr.resultType),
resultType)
.getOrElse(
throw new CodeGenException(s"Unsupported call: " +
s"$sqlOperator(${operands.map(_.resultType).mkString(", ")}) \n" +
s"If you think this function should be supported, " +
s"you can create an issue and start a discussion for it."))
.generate(ctx, operands, resultType)
}
// unknown or invalid
case call@_ =>
......
......@@ -570,17 +570,20 @@ object GenerateUtils {
* Wrapper types can autoboxed to their corresponding primitive type (Integer -> int).
*
* @param ctx code generator context which maintains various code statements.
* @param fieldType type of field
* @param fieldTerm expression term of field to be unboxed
* @param inputType type of field
* @param inputTerm expression term of field to be unboxed
* @param inputUnboxingTerm unboxing/conversion term
* @return internal unboxed field representation
*/
def generateInputFieldUnboxing(
ctx: CodeGeneratorContext,
fieldType: LogicalType,
fieldTerm: String): GeneratedExpression = {
inputType: LogicalType,
inputTerm: String,
inputUnboxingTerm: String)
: GeneratedExpression = {
val resultTypeTerm = primitiveTypeTermForType(fieldType)
val defaultValue = primitiveDefaultValue(fieldType)
val resultTypeTerm = primitiveTypeTermForType(inputType)
val defaultValue = primitiveDefaultValue(inputType)
val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables(
(resultTypeTerm, "result"),
......@@ -588,19 +591,19 @@ object GenerateUtils {
val wrappedCode = if (ctx.nullCheck) {
s"""
|$nullTerm = $fieldTerm == null;
|$nullTerm = $inputTerm == null;
|$resultTerm = $defaultValue;
|if (!$nullTerm) {
| $resultTerm = $fieldTerm;
| $resultTerm = $inputUnboxingTerm;
|}
|""".stripMargin.trim
} else {
s"""
|$resultTerm = $fieldTerm;
|$resultTerm = $inputUnboxingTerm;
|""".stripMargin.trim
}
GeneratedExpression(resultTerm, nullTerm, wrappedCode, fieldType)
GeneratedExpression(resultTerm, nullTerm, wrappedCode, inputType)
}
/**
......@@ -659,7 +662,7 @@ object GenerateUtils {
case _ =>
val fieldTypeTerm = boxedTypeTermForType(inputType)
val inputCode = s"($fieldTypeTerm) $inputTerm"
generateInputFieldUnboxing(ctx, inputType, inputCode)
generateInputFieldUnboxing(ctx, inputType, inputCode, inputCode)
}
/**
......
......@@ -191,7 +191,7 @@ object LookupJoinCodeGenerator {
.map { e =>
val dataType = fromLogicalTypeToDataType(e.resultType)
val bType = if (isExternalArgs) {
boxedTypeTermForExternalType(dataType)
typeTerm(dataType.getConversionClass)
} else {
boxedTypeTermForType(e.resultType)
}
......
......@@ -103,7 +103,7 @@ class ImperativeAggCodeGen(
} else {
boxedTypeTermForType(fromDataTypeToLogicalType(externalAccType))
}
val accTypeExternalTerm: String = boxedTypeTermForExternalType(externalAccType)
val accTypeExternalTerm: String = typeTerm(externalAccType.getConversionClass)
val argTypes: Array[LogicalType] = {
val types = inputTypes ++ constantExprs.map(_.resultType)
......@@ -250,7 +250,7 @@ class ImperativeAggCodeGen(
def getValue(generator: ExprCodeGenerator): GeneratedExpression = {
val valueExternalTerm = newName("value_external")
val valueExternalTypeTerm = boxedTypeTermForExternalType(externalResultType)
val valueExternalTypeTerm = typeTerm(externalResultType.getConversionClass)
val valueInternalTerm = newName("value_internal")
val valueInternalTypeTerm = boxedTypeTermForType(internalResultType)
val nullTerm = newName("valueIsNull")
......@@ -277,8 +277,7 @@ class ImperativeAggCodeGen(
if (f >= inputTypes.length) {
// index to constant
val expr = constantExprs(f - inputTypes.length)
s"${expr.nullTerm} ? null : ${
genToExternal(ctx, externalInputTypes(index), expr.resultTerm)}"
genToExternalIfNeeded(ctx, externalInputTypes(index), expr)
} else {
// index to input field
val inputRef = if (generator.input1Term.startsWith(DISTINCT_KEY_TERM)) {
......@@ -297,8 +296,7 @@ class ImperativeAggCodeGen(
var inputExpr = generator.generateExpression(inputRef.accept(rexNodeGen))
if (inputFieldCopy) inputExpr = inputExpr.deepCopy(ctx)
codes += inputExpr.code
val term = s"${genToExternal(ctx, externalInputTypes(index), inputExpr.resultTerm)}"
s"${inputExpr.nullTerm} ? null : $term"
genToExternalIfNeeded(ctx, externalInputTypes(index), inputExpr)
}
}
......
......@@ -554,7 +554,7 @@ object AggCodeGenHelper {
val singleIterableClass = classOf[SingleElementIterator[_]].getCanonicalName
val externalAccT = getAccumulatorTypeOfAggregateFunction(agg)
val javaField = boxedTypeTermForExternalType(externalAccT)
val javaField = typeTerm(externalAccT.getConversionClass)
val tmpAcc = newName("tmpAcc")
s"""
|final $singleIterableClass accIt$aggIndex = new $singleIterableClass();
......@@ -625,11 +625,10 @@ object AggCodeGenHelper {
agg, externalAccType, inputExprs.map(_.resultType))
val parameters = inputExprs.zipWithIndex.map {
case (expr, i) =>
s"${expr.nullTerm} ? null : " +
s"${ genToExternal(ctx, externalUDITypes(i), expr.resultTerm)}"
genToExternalIfNeeded(ctx, externalUDITypes(i), expr)
}
val javaTerm = boxedTypeTermForExternalType(externalAccType)
val javaTerm = typeTerm(externalAccType.getConversionClass)
val tmpAcc = newName("tmpAcc")
val innerCode =
s"""
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.codegen.calls
import java.lang.reflect.Method
import java.util.Collections
import org.apache.calcite.rex.{RexCall, RexCallBinding}
import org.apache.flink.table.functions.UserDefinedFunctionHelper.SCALAR_EVAL
import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
import org.apache.flink.table.planner.codegen.CodeGenUtils.{genToExternalIfNeeded, genToInternalIfNeeded, typeTerm}
import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGeneratorContext, GeneratedExpression}
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction
import org.apache.flink.table.planner.functions.inference.OperatorBindingCallContext
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.extraction.utils.ExtractionUtils
import org.apache.flink.table.types.extraction.utils.ExtractionUtils.{createMethodSignatureString, isAssignable, isMethodInvokable, primitiveToWrapper}
import org.apache.flink.table.types.inference.TypeInferenceUtil
import org.apache.flink.table.types.logical.LogicalType
/**
 * Generates a call to a user-defined [[ScalarFunction]] or [[TableFunction]] (future work).
 *
 * The generator enriches the argument and output types of the call with conversion classes
 * using the type inference of the function, verifies them against the types of the planner and
 * against the implementation methods of the function, and finally emits code that converts the
 * operands to their external format, invokes the function, and converts the result back to the
 * internal format.
 *
 * @param call the call whose operator is a [[BridgingSqlFunction]]
 */
class BridgingSqlFunctionCallGen(call: RexCall) extends CallGenerator {

  /** Operator of the call; the cast fails if the call is not a [[BridgingSqlFunction]] call. */
  private val function: BridgingSqlFunction = call.getOperator.asInstanceOf[BridgingSqlFunction]

  /** Function instance that is invoked by the generated code. */
  private val udf: UserDefinedFunction = function.getDefinition.asInstanceOf[UserDefinedFunction]

  /**
   * Generates the expression for calling the function.
   *
   * @param ctx code generator context which maintains various code statements
   * @param operands already generated operands of the call in internal format
   * @param returnType result type expected by the planner
   * @return expression that holds the function result in internal format
   * @throws CodeGenException if the enriched types do not match the planner types or if no
   *                          suitable implementation method exists
   */
  override def generate(
      ctx: CodeGeneratorContext,
      operands: Seq[GeneratedExpression],
      returnType: LogicalType)
    : GeneratedExpression = {

    val inference = function.getTypeInference

    // we could have implemented a dedicated code generation context but the closer we are to
    // Calcite the more consistent is the type inference during the data type enrichment
    val callContext = new OperatorBindingCallContext(
      function.getDataTypeFactory,
      udf,
      RexCallBinding.create(
        function.getTypeFactory,
        call,
        Collections.emptyList()))

    // enrich argument types with conversion class
    val adaptedCallContext = TypeInferenceUtil.adaptArguments(
      inference,
      callContext,
      null)
    val enrichedArgumentDataTypes = toScala(adaptedCallContext.getArgumentDataTypes)
    verifyArgumentTypes(operands.map(_.resultType), enrichedArgumentDataTypes)

    // enrich output types with conversion class
    val enrichedOutputDataType = TypeInferenceUtil.inferOutputType(
      adaptedCallContext,
      inference.getOutputTypeStrategy)
    verifyOutputType(returnType, enrichedOutputDataType)

    // find runtime method and generate call
    verifyImplementation(enrichedArgumentDataTypes, enrichedOutputDataType)
    generateFunctionCall(ctx, operands, enrichedArgumentDataTypes, enrichedOutputDataType)
  }

  /**
   * Emits the code of the actual invocation: operand code, the call of the method named by
   * SCALAR_EVAL on the reusable function instance, and the conversion of the external result
   * to the internal format.
   */
  private def generateFunctionCall(
      ctx: CodeGeneratorContext,
      operands: Seq[GeneratedExpression],
      argumentDataTypes: Seq[DataType],
      outputDataType: DataType)
    : GeneratedExpression = {

    val functionTerm = ctx.addReusableFunction(udf)

    // operand conversion
    val externalOperands = prepareExternalOperands(ctx, operands, argumentDataTypes)
    val externalOperandTerms = externalOperands.map(_.resultTerm).mkString(", ")

    // result conversion
    val externalResultClass = outputDataType.getConversionClass
    val externalResultTypeTerm = typeTerm(externalResultClass)
    // Janino does not fully support the JVM spec:
    // boolean b = (boolean) f(); where f returns Object
    // This is not supported and we need to box manually.
    val externalResultClassBoxed = primitiveToWrapper(externalResultClass)
    val externalResultCasting = if (externalResultClass == externalResultClassBoxed) {
      s"($externalResultTypeTerm)"
    } else {
      s"($externalResultTypeTerm) (${typeTerm(externalResultClassBoxed)})"
    }
    val externalResultTerm = ctx.addReusableLocalVariable(externalResultTypeTerm, "externalResult")
    val internalExpr = genToInternalIfNeeded(ctx, outputDataType, externalResultTerm)

    // function call
    internalExpr.copy(code =
      s"""
        |${externalOperands.map(_.code).mkString("\n")}
        |$externalResultTerm = $externalResultCasting $functionTerm
        | .$SCALAR_EVAL($externalOperandTerms);
        |${internalExpr.code}
        |""".stripMargin)
  }

  /**
   * Replaces the result term of every operand by a term that yields the operand in the external
   * format of the corresponding argument data type. The code of the operands is kept.
   */
  private def prepareExternalOperands(
      ctx: CodeGeneratorContext,
      operands: Seq[GeneratedExpression],
      argumentDataTypes: Seq[DataType])
    : Seq[GeneratedExpression] = {
    operands
      .zip(argumentDataTypes)
      .map { case (operand, dataType) =>
        operand.copy(resultTerm = genToExternalIfNeeded(ctx, dataType, operand))
      }
  }

  /**
   * Verifies the enriched argument data types against the types of the generated operands.
   *
   * @throws CodeGenException if a logical type changed during the enrichment or if a data type
   *                          does not support an output conversion to its conversion class
   */
  private def verifyArgumentTypes(
      operandTypes: Seq[LogicalType],
      enrichedDataTypes: Seq[DataType])
    : Unit = {
    val enrichedTypes = enrichedDataTypes.map(_.getLogicalType)
    operandTypes.zip(enrichedTypes).foreach { case (operandType, enrichedType) =>
      // check that the logical type has not changed during the enrichment
      // a nullability mismatch is acceptable if the enriched type can handle it
      if (operandType != enrichedType && operandType.copy(true) != enrichedType) {
        throw new CodeGenException(
          s"Mismatch of function's argument data type '$enrichedType' and actual " +
            s"argument type '$operandType'.")
      }
    }
    // the data type class can only partially verify the conversion class,
    // now is the time for the final check
    enrichedDataTypes.foreach(dataType => {
      if (!dataType.getLogicalType.supportsOutputConversion(dataType.getConversionClass)) {
        throw new CodeGenException(
          s"Data type '$dataType' does not support an output conversion " +
            s"to class '${dataType.getConversionClass}'.")
      }
    })
  }

  /**
   * Verifies the enriched output data type against the result type expected by the planner.
   *
   * @throws CodeGenException if the logical type changed during the enrichment or if the data
   *                          type does not support an input conversion for its conversion class
   */
  private def verifyOutputType(
      outputType: LogicalType,
      enrichedDataType: DataType)
    : Unit = {
    val enrichedType = enrichedDataType.getLogicalType
    // check that the logical type has not changed during the enrichment
    // a nullability mismatch is acceptable if the output type can handle it
    if (outputType != enrichedType && outputType != enrichedType.copy(true)) {
      throw new CodeGenException(
        s"Mismatch of expected output data type '$outputType' and function's " +
          s"output type '$enrichedType'.")
    }
    // the data type class can only partially verify the conversion class,
    // now is the time for the final check
    if (!enrichedType.supportsInputConversion(enrichedDataType.getConversionClass)) {
      throw new CodeGenException(
        s"Data type '$enrichedDataType' does not support an input conversion " +
          s"to class '${enrichedDataType.getConversionClass}'.")
    }
  }

  /**
   * Verifies that the function class offers at least one method named by SCALAR_EVAL that can
   * be invoked with the conversion classes of the arguments and whose return type is compatible
   * with the conversion class of the output.
   *
   * @throws CodeGenException if no such method exists
   */
  private def verifyImplementation(
      argumentDataTypes: Seq[DataType],
      outputDataType: DataType)
    : Unit = {
    val methods = toScala(ExtractionUtils.collectMethods(udf.getClass, SCALAR_EVAL))
    val argumentClasses = argumentDataTypes.map(_.getConversionClass).toArray
    val outputClass = outputDataType.getConversionClass
    // verifies regular JVM calling semantics
    def methodMatches(method: Method): Boolean = {
      isMethodInvokable(method, argumentClasses: _*) &&
        isAssignable(outputClass, method.getReturnType, true)
    }
    if (!methods.exists(methodMatches)) {
      // typeTerm() is not used for the message on purpose: it throws for classes without a
      // canonical name (e.g. anonymous or local classes) and would thereby hide the actual
      // problem, the missing implementation method
      val udfClass = udf.getClass
      val udfClassName = Option(udfClass.getCanonicalName).getOrElse(udfClass.getName)
      throw new CodeGenException(
        s"Could not find an implementation method in class '$udfClassName' for " +
          s"function '$function' that matches the following signature: \n" +
          s"${createMethodSignatureString(SCALAR_EVAL, argumentClasses, outputClass)}")
    }
  }
}
......@@ -30,6 +30,7 @@ import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType
import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.extraction.utils.ExtractionUtils
import org.apache.flink.table.types.logical.LogicalType
import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType
......@@ -107,7 +108,7 @@ class ScalarFunctionCallGen(scalarFunction: ScalarFunction) extends CallGenerato
val resultUnboxing = if (resultClass.isPrimitive) {
GenerateUtils.generateNonNullField(returnType, resultTerm)
} else {
GenerateUtils.generateInputFieldUnboxing(ctx, returnType, resultTerm)
GenerateUtils.generateInputFieldUnboxing(ctx, returnType, resultTerm, resultTerm)
}
resultUnboxing.copy(code =
s"""
......@@ -126,6 +127,17 @@ class ScalarFunctionCallGen(scalarFunction: ScalarFunction) extends CallGenerato
prepareFunctionArgs(ctx, operands, paramClasses, func.getParameterTypes(paramClasses))
}
/**
 * Yields a term for the given external term in internal format. Terms of internal classes are
 * only cast to the boxed internal type, all other terms are passed to `genToInternal`.
 */
def genToInternalIfNeeded(
    ctx: CodeGeneratorContext,
    t: DataType,
    term: String): String = {
  if (!isInternalClass(t)) {
    genToInternal(ctx, t, term)
  } else {
    val internalTypeTerm =
      boxedTypeTermForType(LogicalTypeDataTypeConverter.fromDataTypeToLogicalType(t))
    s"($internalTypeTerm) $term"
  }
}
}
object ScalarFunctionCallGen {
......@@ -161,10 +173,8 @@ object ScalarFunctionCallGen {
} else {
signatureTypes(i)
}
val externalResultTerm = genToExternalIfNeeded(
ctx, signatureType, operandExpr.resultTerm)
val exprOrNull = s"${operandExpr.nullTerm} ? null : ($externalResultTerm)"
operandExpr.copy(resultTerm = exprOrNull)
val externalResultTerm = genToExternalIfNeeded(ctx, signatureType, operandExpr)
operandExpr.copy(resultTerm = externalResultTerm)
}
}
}
......
......@@ -18,27 +18,37 @@
package org.apache.flink.table.planner.runtime.stream.sql;
import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.annotation.InputGroup;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.catalog.Catalog;
import org.apache.flink.table.catalog.CatalogFunction;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.catalog.ObjectPath;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.planner.codegen.CodeGenException;
import org.apache.flink.table.planner.factories.utils.TestCollectionTableFactory;
import org.apache.flink.table.planner.runtime.utils.StreamingTestBase;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.utils.EncodingUtils;
import org.apache.flink.types.Row;
import org.junit.Test;
import java.util.ArrayList;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.List;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage;
/**
* Tests for catalog and system in stream table environment.
......@@ -400,7 +410,7 @@ public class FunctionITCase extends StreamingTestBase {
);
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData, new ArrayList<>(), -1);
TestCollectionTableFactory.initData(sourceData);
String sourceDDL = "create table t1(a int, b varchar, c int) with ('connector' = 'COLLECTION')";
String sinkDDL = "create table t2(a int, b varchar, c int) with ('connector' = 'COLLECTION')";
......@@ -421,4 +431,178 @@ public class FunctionITCase extends StreamingTestBase {
tEnv().sqlUpdate("drop table t1");
tEnv().sqlUpdate("drop table t2");
}
@Test
public void testPrimitiveScalarFunction() throws Exception {
	// rows that are served by the collection source
	final List<Row> inputRows = Arrays.asList(
		Row.of(1, 1L, "-"),
		Row.of(2, 2L, "--"),
		Row.of(3, 3L, "---")
	);

	// rows that are expected in the collection sink after the function was applied
	final List<Row> expectedRows = Arrays.asList(
		Row.of(1, 3L, "-"),
		Row.of(2, 6L, "--"),
		Row.of(3, 9L, "---")
	);

	TestCollectionTableFactory.reset();
	TestCollectionTableFactory.initData(inputRows);

	final String tableDdl =
		"CREATE TABLE TestTable(a INT NOT NULL, b BIGINT NOT NULL, c STRING) WITH ('connector' = 'COLLECTION')";
	final String insertStatement =
		"INSERT INTO TestTable SELECT a, PrimitiveScalarFunction(a, b, c), c FROM TestTable";

	tEnv().sqlUpdate(tableDdl);
	tEnv().createTemporarySystemFunction("PrimitiveScalarFunction", PrimitiveScalarFunction.class);
	tEnv().sqlUpdate(insertStatement);
	tEnv().execute("Test Job");

	assertThat(TestCollectionTableFactory.getResult(), equalTo(expectedRows));
}
@Test
public void testComplexScalarFunction() throws Exception {
	// rows that are served by the collection source, including a row that consists of nulls
	final List<Row> inputRows = Arrays.asList(
		Row.of(1, new byte[]{1, 2, 3}),
		Row.of(2, new byte[]{2, 3, 4}),
		Row.of(3, new byte[]{3, 4, 5}),
		Row.of(null, null)
	);

	// one column per overloaded call of the function
	final List<Row> expectedRows = Arrays.asList(
		Row.of(1, "1+2012-12-12 12:12:12.123456789", "[1, 2, 3]+2012-12-12 12:12:12.123456789", new BigDecimal("123.40"), "[1, 2, 3]"),
		Row.of(2, "2+2012-12-12 12:12:12.123456789", "[2, 3, 4]+2012-12-12 12:12:12.123456789", new BigDecimal("123.40"), "[2, 3, 4]"),
		Row.of(3, "3+2012-12-12 12:12:12.123456789", "[3, 4, 5]+2012-12-12 12:12:12.123456789", new BigDecimal("123.40"), "[3, 4, 5]"),
		Row.of(null, "null+2012-12-12 12:12:12.123456789", "null+2012-12-12 12:12:12.123456789", new BigDecimal("123.40"), "null")
	);

	TestCollectionTableFactory.reset();
	TestCollectionTableFactory.initData(inputRows);

	final String sourceDdl =
		"CREATE TABLE SourceTable(i INT, b BYTES) " +
		"WITH ('connector' = 'COLLECTION')";
	final String sinkDdl =
		"CREATE TABLE SinkTable(i INT, s1 STRING, s2 STRING, d DECIMAL(5, 2), s3 STRING) " +
		"WITH ('connector' = 'COLLECTION')";
	final String insertStatement =
		"INSERT INTO SinkTable " +
		"SELECT " +
		"  i, " +
		"  ComplexScalarFunction(i, TIMESTAMP '2012-12-12 12:12:12.123456789'), " +
		"  ComplexScalarFunction(b, TIMESTAMP '2012-12-12 12:12:12.123456789')," +
		"  ComplexScalarFunction(), " +
		"  ComplexScalarFunction(b) " +
		"FROM SourceTable";

	tEnv().sqlUpdate(sourceDdl);
	tEnv().sqlUpdate(sinkDdl);
	tEnv().createTemporarySystemFunction("ComplexScalarFunction", ComplexScalarFunction.class);
	tEnv().sqlUpdate(insertStatement);
	tEnv().execute("Test Job");

	assertThat(TestCollectionTableFactory.getResult(), equalTo(expectedRows));
}
@Test
public void testCustomScalarFunction() throws Exception {
	// rows that are served by the collection source, the last one holds a null value
	final List<Row> inputRows = Arrays.asList(
		Row.of(1),
		Row.of(2),
		Row.of(3),
		Row.of((Integer) null)
	);

	// rows that are expected in the collection sink
	final List<Row> expectedRows = Arrays.asList(
		Row.of(1, 1, 5),
		Row.of(2, 2, 5),
		Row.of(3, 3, 5),
		Row.of(null, null, 5)
	);

	TestCollectionTableFactory.reset();
	TestCollectionTableFactory.initData(inputRows);

	final String sourceDdl = "CREATE TABLE SourceTable(i INT) WITH ('connector' = 'COLLECTION')";
	final String sinkDdl = "CREATE TABLE SinkTable(i1 INT, i2 INT, i3 INT) WITH ('connector' = 'COLLECTION')";
	final String insertStatement =
		"INSERT INTO SinkTable " +
		"SELECT " +
		"  i, " +
		"  CustomScalarFunction(i), " +
		"  CustomScalarFunction(CAST(NULL AS INT), 5, i, i) " +
		"FROM SourceTable";

	tEnv().sqlUpdate(sourceDdl);
	tEnv().sqlUpdate(sinkDdl);
	tEnv().createTemporarySystemFunction("CustomScalarFunction", CustomScalarFunction.class);
	tEnv().sqlUpdate(insertStatement);
	tEnv().execute("Test Job");

	assertThat(TestCollectionTableFactory.getResult(), equalTo(expectedRows));
}
@Test
public void testInvalidCustomScalarFunction() {
	// message that is expected because the function offers no eval method for strings
	final String expectedMessage =
		"Could not find an implementation method in class '" + CustomScalarFunction.class.getCanonicalName() +
		"' for function 'CustomScalarFunction' that matches the following signature: \n" +
		"java.lang.String eval(java.lang.String)";
	final String insertStatement =
		"INSERT INTO SinkTable " +
		"SELECT CustomScalarFunction('test')";

	tEnv().sqlUpdate("CREATE TABLE SinkTable(s STRING) WITH ('connector' = 'COLLECTION')");
	tEnv().createTemporarySystemFunction("CustomScalarFunction", CustomScalarFunction.class);

	try {
		tEnv().sqlUpdate(insertStatement);
		fail();
	} catch (CodeGenException e) {
		assertThat(e, hasMessage(equalTo(expectedMessage)));
	}
}
// --------------------------------------------------------------------------------------------
// Test functions
// --------------------------------------------------------------------------------------------
/**
 * Scalar function whose parameters and result are primitives (apart from the string argument).
 */
public static class PrimitiveScalarFunction extends ScalarFunction {
	public long eval(int i, long l, String s) {
		final long numericSum = i + l;
		return numericSum + s.length();
	}
}
/**
 * Overloaded scalar function that makes use of annotations for its data types.
 */
public static class ComplexScalarFunction extends ScalarFunction {

	/** Joins the string representation of an arbitrary object and of a timestamp with a plus sign. */
	public String eval(@DataTypeHint(inputGroup = InputGroup.ANY) Object o, java.sql.Timestamp t) {
		final String objectPart = EncodingUtils.objectToString(o);
		final String timestampPart = t.toString();
		return objectPart + "+" + timestampPart;
	}

	/** Returns a constant decimal that is declared as DECIMAL(5, 2). */
	public @DataTypeHint("DECIMAL(5, 2)") BigDecimal eval() {
		// 1 digit is missing
		final BigDecimal constant = new BigDecimal("123.4");
		return constant;
	}

	/** Returns the string representation of the given bytes. */
	public String eval(byte[] bytes) {
		return Arrays.toString(bytes);
	}
}
/**
* Function that has a custom type inference that is broader than the actual implementation.
*/
public static class CustomScalarFunction extends ScalarFunction {
public Integer eval(Integer... args) {
for (Integer o : args) {
if (o != null) {
return o;
}
}
return null;
}
@Override
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
return TypeInference.newBuilder()
.outputTypeStrategy(TypeStrategies.argument(0))
.build();
}
}
}
......@@ -97,6 +97,10 @@ object TestCollectionTableFactory {
val RESULT = new JLinkedList[Row]()
private var emitIntervalMS = -1L
// Convenience overload that only takes the source rows and forwards to the three-argument
// initData with `List()` as dimData and `-1L` as emitInterval. The Java test in this change
// calls initData with a single argument.
def initData(sourceData: JList[Row]): Unit ={
initData(sourceData, List(), -1L)
}
def initData(sourceData: JList[Row],
dimData: JList[Row] = List(),
emitInterval: Long = -1L): Unit ={
......@@ -112,6 +116,8 @@ object TestCollectionTableFactory {
emitIntervalMS = -1L
}
// Exposes the shared RESULT list itself (no copy is made).
def getResult: util.List[Row] = RESULT
def getCollectionSource(props: JMap[String, String]): CollectionTableSource = {
val properties = new DescriptorProperties()
properties.putProperties(props)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册