Commit e9decac6 authored by twalthr

[FLINK-7439] [table] Add possibility to replace TypeExtractor call

Parent 79c17afa
@@ -297,7 +297,7 @@ optionally implemented. While some of these methods allow the system more effici
 - `merge()` is required for many batch aggregations and session window aggregations.
 - `resetAccumulator()` is required for many batch aggregations.
-All methods of `AggregateFunction` must be declared as `public`, not `static` and named exactly as the names mentioned above. The methods `createAccumulator`, `getValue`, `getResultType`, and `getAccumulatorType` are defined in the `AggregateFunction` abstract class, while others are contracted methods. In order to define a table function, one has to extend the base class `org.apache.flink.table.functions.AggregateFunction` and implement one (or more) `accumulate` methods. The method `accumulate` can be overloaded with different custom types and arguments and also support variable arguments.
+All methods of `AggregateFunction` must be declared as `public`, not `static` and named exactly as the names mentioned above. The methods `createAccumulator`, `getValue`, `getResultType`, and `getAccumulatorType` are defined in the `AggregateFunction` abstract class, while others are contracted methods. In order to define an aggregate function, one has to extend the base class `org.apache.flink.table.functions.AggregateFunction` and implement one (or more) `accumulate` methods. The method `accumulate` can be overloaded with different parameter types and supports variable arguments.
 Detailed documentation for all methods of `AggregateFunction` is given below.
...
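The documentation changed above lists the required and contracted methods. A minimal sketch of such a function, assuming an illustrative counting aggregate (class and field names are not part of this commit), could look like this:

```scala
import org.apache.flink.table.functions.AggregateFunction

// Illustrative accumulator: holds the running count of accumulated values.
class CountAccumulator {
  var count: Long = 0L
}

// Minimal sketch of a user-defined aggregate function that counts string values.
class CountAggregate extends AggregateFunction[java.lang.Long, CountAccumulator] {

  // Defined in the abstract class: creates and initializes the accumulator.
  override def createAccumulator(): CountAccumulator = new CountAccumulator

  // Defined in the abstract class: computes the final result from the accumulator.
  override def getValue(accumulator: CountAccumulator): java.lang.Long = accumulator.count

  // Contracted method: processes one input value; can be overloaded with other parameter types.
  def accumulate(accumulator: CountAccumulator, value: String): Unit = {
    accumulator.count += 1
  }
}
```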
@@ -83,7 +83,7 @@ abstract class ScalarFunction extends UserDefinedFunction {
     * more complex, custom, or composite types.
     *
     * @param signature signature of the method the operand types need to be determined
     * @return [[TypeInformation]] of operand types
     */
   def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = {
     signature.map { c =>
...
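For context, the `getParameterTypes` hook shown above can be overridden when the default type extraction would fail or yield a wrong type. A sketch, assuming a simple scalar function (the class name and the chosen type information are illustrative):

```scala
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.table.functions.ScalarFunction

// Sketch of a scalar function that declares its operand type explicitly
// instead of relying on Flink's default type extraction.
class HashCodeFunction extends ScalarFunction {

  def eval(s: String): Int = s.hashCode

  // Return explicit TypeInformation for every parameter of the matched signature.
  override def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = {
    Array[TypeInformation[_]](BasicTypeInfo.STRING_TYPE_INFO)
  }
}
```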
@@ -18,7 +18,10 @@
 package org.apache.flink.table.functions
 
+import org.apache.flink.api.common.functions.InvalidTypesException
 import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.table.api.ValidationException
 import org.apache.flink.util.Collector
 
 /**
@@ -119,4 +122,29 @@ abstract class TableFunction[T] extends UserDefinedFunction {
     */
   def getResultType: TypeInformation[T] = null
 
+  /**
+    * Returns [[TypeInformation]] about the operands of the evaluation method with a given
+    * signature.
+    *
+    * In order to perform operand type inference in SQL (especially when NULL is used) it might be
+    * necessary to determine the parameter [[TypeInformation]] of an evaluation method.
+    * By default Flink's type extraction facilities are used for this but might be wrong for
+    * more complex, custom, or composite types.
+    *
+    * @param signature signature of the method the operand types need to be determined
+    * @return [[TypeInformation]] of operand types
+    */
+  def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = {
+    signature.map { c =>
+      try {
+        TypeExtractor.getForClass(c)
+      } catch {
+        case ite: InvalidTypesException =>
+          throw new ValidationException(
+            s"Parameter types of table function '${this.getClass.getCanonicalName}' cannot be " +
+            s"automatically determined. Please provide type information manually.")
+      }
+    }
+  }
 }
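With this default implementation in place, the TypeExtractor call becomes replaceable: a table function can override `getParameterTypes` and supply the operand types itself. A sketch under the assumption of a simple split function (names are illustrative, not part of this commit):

```scala
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.table.functions.TableFunction

// Sketch of a table function that bypasses the default TypeExtractor-based
// implementation by providing the parameter types explicitly.
class SplitFunction extends TableFunction[String] {

  def eval(str: String, separator: String): Unit = {
    str.split(separator).foreach(collect)
  }

  // Provide TypeInformation for each parameter class of the matched eval signature.
  override def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = {
    signature.map(_ => BasicTypeInfo.STRING_TYPE_INFO: TypeInformation[_])
  }
}
```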
@@ -25,7 +25,6 @@ import org.apache.calcite.sql.parser.SqlParserPos
 import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction
 import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
 import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.java.typeutils.TypeExtractor
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.functions.TableFunction
 import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
@@ -38,44 +37,45 @@ import org.apache.flink.table.functions.utils.TableSqlFunction._
   */
 class TableSqlFunction(
     name: String,
-    udtf: TableFunction[_],
+    tableFunction: TableFunction[_],
     rowTypeInfo: TypeInformation[_],
     typeFactory: FlinkTypeFactory,
     functionImpl: FlinkTableFunctionImpl[_])
   extends SqlUserDefinedTableFunction(
     new SqlIdentifier(name, SqlParserPos.ZERO),
     ReturnTypes.CURSOR,
-    createOperandTypeInference(name, udtf, typeFactory),
-    createOperandTypeChecker(name, udtf),
+    createOperandTypeInference(name, tableFunction, typeFactory),
+    createOperandTypeChecker(name, tableFunction),
     null,
     functionImpl) {
 
   /**
     * Get the user-defined table function.
     */
-  def getTableFunction = udtf
+  def getTableFunction: TableFunction[_] = tableFunction
 
   /**
     * Get the type information of the table returned by the table function.
     */
-  def getRowTypeInfo = rowTypeInfo
+  def getRowTypeInfo: TypeInformation[_] = rowTypeInfo
 
   /**
     * Get additional mapping information if the returned table type is a POJO
     * (POJO types have no deterministic field order).
     */
-  def getPojoFieldMapping = functionImpl.fieldIndexes
+  def getPojoFieldMapping: Array[Int] = functionImpl.fieldIndexes
 
-  override def isDeterministic: Boolean = udtf.isDeterministic
+  override def isDeterministic: Boolean = tableFunction.isDeterministic
 }
 
 object TableSqlFunction {
 
   private[flink] def createOperandTypeInference(
       name: String,
-      udtf: TableFunction[_],
+      tableFunction: TableFunction[_],
       typeFactory: FlinkTypeFactory)
     : SqlOperandTypeInference = {
     /**
       * Operand type inference based on [[TableFunction]] given information.
       */
@@ -87,14 +87,14 @@
         val operandTypeInfo = getOperandTypeInfo(callBinding)
 
-        val foundSignature = getEvalMethodSignature(udtf, operandTypeInfo)
+        val foundSignature = getEvalMethodSignature(tableFunction, operandTypeInfo)
           .getOrElse(throw new ValidationException(
             s"Given parameters of function '$name' do not match any signature. \n" +
               s"Actual: ${signatureToString(operandTypeInfo)} \n" +
-              s"Expected: ${signaturesToString(udtf, "eval")}"))
+              s"Expected: ${signaturesToString(tableFunction, "eval")}"))
 
-        val inferredTypes = foundSignature
-          .map(TypeExtractor.getForClass(_))
+        val inferredTypes = tableFunction
+          .getParameterTypes(foundSignature)
           .map(typeFactory.createTypeFromTypeInfo(_, isNullable = true))
 
         for (i <- operandTypes.indices) {
@@ -113,17 +113,17 @@
   private[flink] def createOperandTypeChecker(
       name: String,
-      udtf: TableFunction[_])
+      tableFunction: TableFunction[_])
     : SqlOperandTypeChecker = {
 
-    val signatures = getMethodSignatures(udtf, "eval")
+    val signatures = getMethodSignatures(tableFunction, "eval")
 
     /**
       * Operand type checker based on [[TableFunction]] given information.
       */
     new SqlOperandTypeChecker {
 
       override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
-        s"$opName[${signaturesToString(udtf, "eval")}]"
+        s"$opName[${signaturesToString(tableFunction, "eval")}]"
       }
 
       override def getOperandCountRange: SqlOperandCountRange = {
@@ -147,14 +147,14 @@
         : Boolean = {
         val operandTypeInfo = getOperandTypeInfo(callBinding)
 
-        val foundSignature = getEvalMethodSignature(udtf, operandTypeInfo)
+        val foundSignature = getEvalMethodSignature(tableFunction, operandTypeInfo)
 
         if (foundSignature.isEmpty) {
           if (throwOnFailure) {
             throw new ValidationException(
               s"Given parameters of function '$name' do not match any signature. \n" +
                 s"Actual: ${signatureToString(operandTypeInfo)} \n" +
-                s"Expected: ${signaturesToString(udtf, "eval")}")
+                s"Expected: ${signaturesToString(tableFunction, "eval")}")
           } else {
             false
           }
...
@@ -209,7 +209,6 @@ class CorrelateTest extends TableTestBase {
     util.verifySql(sqlQuery, expected)
   }
 
   @Test
   def testScalarFunction(): Unit = {
     val util = batchTestUtil()
...