Commit 3de1205c authored by huangxingbo, committed by hequn8128

[FLINK-15972][python][table-planner][table-planner-blink] Add Python building blocks to make sure the basic functionality of Python TableFunction could work

This closes #11130.
Parent 2d0841cb
@@ -27,51 +27,48 @@ from pyflink.fn_execution import flink_fn_execution_pb2
from pyflink.serializers import PickleSerializer
SCALAR_FUNCTION_URN = "flink:transform:scalar_function:v1"
TABLE_FUNCTION_URN = "flink:transform:table_function:v1"
class ScalarFunctionOperation(Operation):
class StatelessFunctionOperation(Operation):
"""
An operation that will execute ScalarFunctions for each input element.
Base class for stateless function operations that execute a ScalarFunction or TableFunction for
each input element.
"""
def __init__(self, name, spec, counter_factory, sampler, consumers):
super(ScalarFunctionOperation, self).__init__(name, spec, counter_factory, sampler)
super(StatelessFunctionOperation, self).__init__(name, spec, counter_factory, sampler)
self.consumer = consumers['output'][0]
self._value_coder_impl = self.consumer.windowed_coder.wrapped_value_coder.get_impl()
self.variable_dict = {}
self.scalar_funcs = []
self.func = self._generate_func(self.spec.serialized_fn)
for scalar_func in self.scalar_funcs:
scalar_func.open(None)
self.user_defined_funcs = []
self.func = self.generate_func(self.spec.serialized_fn)
for user_defined_func in self.user_defined_funcs:
user_defined_func.open(None)
def setup(self):
super(ScalarFunctionOperation, self).setup()
super(StatelessFunctionOperation, self).setup()
def start(self):
with self.scoped_start_state:
super(ScalarFunctionOperation, self).start()
def process(self, o):
output_stream = self.consumer.output_stream
self._value_coder_impl.encode_to_stream(self.func(o.value), output_stream, True)
output_stream.maybe_flush()
super(StatelessFunctionOperation, self).start()
def finish(self):
super(ScalarFunctionOperation, self).finish()
super(StatelessFunctionOperation, self).finish()
def needs_finalization(self):
return False
def reset(self):
super(ScalarFunctionOperation, self).reset()
super(StatelessFunctionOperation, self).reset()
def teardown(self):
for scalar_func in self.scalar_funcs:
scalar_func.close(None)
for user_defined_func in self.user_defined_funcs:
user_defined_func.close(None)
def progress_metrics(self):
metrics = super(ScalarFunctionOperation, self).progress_metrics()
metrics = super(StatelessFunctionOperation, self).progress_metrics()
metrics.processed_elements.measured.output_element_counts.clear()
tag = None
receiver = self.receivers[0]
@@ -79,21 +76,16 @@ class ScalarFunctionOperation(Operation):
str(tag)] = receiver.opcounter.element_counter.value()
return metrics
def _generate_func(self, udfs):
"""
Generates a lambda function based on udfs.
:param udfs: a list of the proto representation of the Python :class:`ScalarFunction`
:return: the generated lambda function
"""
scalar_functions = [self._extract_scalar_function(udf) for udf in udfs]
return eval('lambda value: [%s]' % ','.join(scalar_functions), self.variable_dict)
def generate_func(self, udfs):
pass
def _extract_scalar_function(self, scalar_function_proto):
def _extract_user_defined_function(self, user_defined_function_proto):
"""
Extracts scalar_function from the proto representation of a
:class:`ScalarFunction`.
Extracts a user-defined function from the proto representation of a
:class:`UserDefinedFunction`.
:param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction`
:param user_defined_function_proto: the proto representation of the Python
:class:`UserDefinedFunction`
"""
def _next_func_num():
if not hasattr(self, "_func_num"):
@@ -102,19 +94,19 @@ class ScalarFunctionOperation(Operation):
self._func_num += 1
return self._func_num
scalar_func = cloudpickle.loads(scalar_function_proto.payload)
user_defined_func = cloudpickle.loads(user_defined_function_proto.payload)
func_name = 'f%s' % _next_func_num()
self.variable_dict[func_name] = scalar_func.eval
self.scalar_funcs.append(scalar_func)
func_args = self._extract_scalar_function_args(scalar_function_proto.inputs)
self.variable_dict[func_name] = user_defined_func.eval
self.user_defined_funcs.append(user_defined_func)
func_args = self._extract_user_defined_function_args(user_defined_function_proto.inputs)
return "%s(%s)" % (func_name, func_args)
def _extract_scalar_function_args(self, args):
def _extract_user_defined_function_args(self, args):
args_str = []
for arg in args:
if arg.HasField("udf"):
# for chaining Python UDF input: the input argument is a Python ScalarFunction
args_str.append(self._extract_scalar_function(arg.udf))
args_str.append(self._extract_user_defined_function(arg.udf))
elif arg.HasField("inputOffset"):
# the input argument is a column of the input row
args_str.append("value[%s]" % arg.inputOffset)
@@ -162,15 +154,69 @@ class ScalarFunctionOperation(Operation):
return constant_value_name
class ScalarFunctionOperation(StatelessFunctionOperation):
def __init__(self, name, spec, counter_factory, sampler, consumers):
super(ScalarFunctionOperation, self).__init__(
name, spec, counter_factory, sampler, consumers)
def generate_func(self, udfs):
"""
Generates a lambda function based on udfs.
:param udfs: a list of the proto representation of the Python :class:`ScalarFunction`
:return: the generated lambda function
"""
scalar_functions = [self._extract_user_defined_function(udf) for udf in udfs]
return eval('lambda value: [%s]' % ','.join(scalar_functions), self.variable_dict)
def process(self, o):
output_stream = self.consumer.output_stream
self._value_coder_impl.encode_to_stream(self.func(o.value), output_stream, True)
output_stream.maybe_flush()
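As an aside for reviewers, here is a minimal standalone sketch (not part of the commit) of what this eval-string machinery produces: _extract_user_defined_function registers each deserialized eval method under a generated name in variable_dict and returns a call-expression string, so chained UDFs become nested calls. The names f0/f1 below are illustrative stand-ins.

# Illustrative sketch of the generated lambda for two chained scalar UDFs.
variable_dict = {
    'f0': lambda x: x + 1,       # stands in for one deserialized udf's eval
    'f1': lambda x, y: x - y,    # stands in for a chained udf's eval
}
func = eval('lambda value: [f0(f1(value[0], value[1]))]', variable_dict)
print(func([10, 3]))  # [8] -- one list entry per top-level scalar function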
class TableFunctionOperation(StatelessFunctionOperation):
def __init__(self, name, spec, counter_factory, sampler, consumers):
super(TableFunctionOperation, self).__init__(
name, spec, counter_factory, sampler, consumers)
def generate_func(self, udtfs):
"""
Generates a lambda function based on udtfs.
:param udtfs: a list of the proto representation of the Python :class:`TableFunction`
:return: the generated lambda function
"""
table_function = self._extract_user_defined_function(udtfs[0])
return eval('lambda value: %s' % table_function, self.variable_dict)
def process(self, o):
output_stream = self.consumer.output_stream
for result in self._create_result(o.value):
self._value_coder_impl.encode_to_stream(result, output_stream, True)
output_stream.maybe_flush()
def _create_result(self, value):
result = self.func(value)
if result is not None:
yield from result
yield None
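The unconditional trailing yield None presumably serves as a per-input end-of-table marker for the downstream coder (an assumption, not stated in the diff). A small standalone sketch of the resulting stream, with a hypothetical func standing in for the generated lambda:

# Sketch of _create_result's output for one input element.
def _create_result(func, value):
    result = func(value)
    if result is not None:
        yield from result
    yield None

print(list(_create_result(lambda v: ((v, i) for i in range(2)), 7)))
# [(7, 0), (7, 1), None]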
@bundle_processor.BeamTransformFactory.register_urn(
SCALAR_FUNCTION_URN, flink_fn_execution_pb2.UserDefinedFunctions)
def create(factory, transform_id, transform_proto, parameter, consumers):
def create_scalar_function(factory, transform_id, transform_proto, parameter, consumers):
return _create_user_defined_function_operation(
factory, transform_proto, consumers, parameter.udfs, ScalarFunctionOperation)
@bundle_processor.BeamTransformFactory.register_urn(
TABLE_FUNCTION_URN, flink_fn_execution_pb2.UserDefinedFunctions)
def create_table_function(factory, transform_id, transform_proto, parameter, consumers):
return _create_user_defined_function_operation(
factory, transform_proto, consumers, parameter.udfs)
factory, transform_proto, consumers, parameter.udfs, TableFunctionOperation)
def _create_user_defined_function_operation(factory, transform_proto, consumers, udfs_proto,
operation_cls=ScalarFunctionOperation):
operation_cls):
output_tags = list(transform_proto.outputs.keys())
output_coders = factory.get_output_coders(transform_proto)
spec = operation_specs.WorkerDoFn(
......
@@ -766,8 +766,8 @@ class TableEnvironment(object):
:param function: The python user-defined function to register.
:type function: pyflink.table.udf.UserDefinedFunctionWrapper
"""
self._j_tenv.registerFunction(name, function._judf(self._is_blink_planner,
self.get_config()._j_table_config))
self._j_tenv.registerFunction(name, function.java_user_defined_function(
self._is_blink_planner, self.get_config()._j_table_config))
@since("1.10.0")
def create_temporary_view(self, view_path, table):
......
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from pyflink.table import DataTypes
from pyflink.table.udf import TableFunction, udtf, ScalarFunction, udf
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase, \
PyFlinkBlinkStreamTableTestCase
class UserDefinedTableFunctionTests(object):
def test_table_function(self):
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c'],
[DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
self.t_env.register_function(
"multi_emit", udtf(MultiEmit(), [DataTypes.BIGINT(), DataTypes.BIGINT()],
[DataTypes.BIGINT(), DataTypes.BIGINT()]))
self.t_env.register_function("condition_multi_emit", condition_multi_emit)
self.t_env.register_function(
"multi_num", udf(MultiNum(), [DataTypes.BIGINT()],
DataTypes.BIGINT()))
t = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])
t.join_lateral("multi_emit(a, multi_num(b)) as (x, y)") \
.left_outer_join_lateral("condition_multi_emit(x, y) as m") \
.select("x, y, m") \
.insert_into("Results")
self.t_env.execute("test")
actual = source_sink_utils.results()
self.assert_equals(actual,
["1,0,null", "1,1,null", "2,0,null", "2,1,null", "3,0,0", "3,0,1",
"3,0,2", "3,1,1", "3,1,2", "3,2,2", "3,3,null"])
class PyFlinkStreamUserDefinedTableFunctionTests(UserDefinedTableFunctionTests,
PyFlinkStreamTableTestCase):
pass
class PyFlinkBlinkStreamUserDefinedTableFunctionTests(UserDefinedTableFunctionTests,
PyFlinkBlinkStreamTableTestCase):
pass
class MultiEmit(TableFunction):
def eval(self, x, y):
for i in range(y):
yield x, i
@udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
result_types=DataTypes.BIGINT())
def condition_multi_emit(x, y):
if x == 3:
return range(y, x)
class MultiNum(ScalarFunction):
def eval(self, x):
return x * 2
if __name__ == '__main__':
import unittest
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
@@ -24,7 +24,7 @@ from pyflink.java_gateway import get_gateway
from pyflink.table.types import DataType, _to_java_type
from pyflink.util import utils
__all__ = ['FunctionContext', 'ScalarFunction', 'udf']
__all__ = ['FunctionContext', 'ScalarFunction', 'TableFunction', 'udf', 'udtf']
class FunctionContext(object):
@@ -86,6 +86,20 @@ class ScalarFunction(UserDefinedFunction):
pass
class TableFunction(UserDefinedFunction):
"""
Base interface for user-defined table functions. A user-defined table function maps zero, one,
or multiple input values to zero, one, or multiple rows.
"""
@abc.abstractmethod
def eval(self, *args):
"""
Method which defines the logic of the table function.
"""
pass
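For illustration, a hypothetical subclass (not part of the commit) showing that mapping; eval may yield rows lazily or return an iterable:

class Split(TableFunction):
    # Hypothetical example: one string in, zero or more (word, length) rows out.
    def eval(self, line):
        for word in line.split():
            yield word, len(word)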
class DelegatingScalarFunction(ScalarFunction):
"""
Helper scalar function implementation for lambda expressions and Python functions. It's for
@@ -99,14 +113,27 @@ class DelegatingScalarFunction(ScalarFunction):
return self.func(*args)
class DelegationTableFunction(TableFunction):
"""
Helper table function implementation for lambda expressions and Python functions. It's for
internal use only.
"""
def __init__(self, func):
self.func = func
def eval(self, *args):
return self.func(*args)
class UserDefinedFunctionWrapper(object):
"""
Wrapper for Python user-defined function. It handles things like converting lambda
Base wrapper for Python user-defined functions. It handles things like converting lambda
functions to user-defined functions, creating the Java user-defined function representation,
etc. It's for internal use only.
"""
def __init__(self, func, input_types, result_type, deterministic=None, name=None):
def __init__(self, func, input_types, deterministic=None, name=None):
if inspect.isclass(func) or (
not isinstance(func, UserDefinedFunction) and not callable(func)):
raise TypeError(
@@ -122,14 +149,8 @@ class UserDefinedFunctionWrapper(object):
"Invalid input_type: input_type should be DataType but contains {}".format(
input_type))
if not isinstance(result_type, DataType):
raise TypeError(
"Invalid returnType: returnType should be DataType but is {}".format(result_type))
self._func = func
self._input_types = input_types
self._result_type = result_type
self._judf_placeholder = None
self._name = name or (
func.__name__ if hasattr(func, '__name__') else func.__class__.__name__)
@@ -142,7 +163,26 @@ class UserDefinedFunctionWrapper(object):
self._deterministic = deterministic if deterministic is not None else (
func.is_deterministic() if isinstance(func, UserDefinedFunction) else True)
def _judf(self, is_blink_planner, table_config):
def java_user_defined_function(self, is_blink_planner, table_config):
pass
class UserDefinedScalarFunctionWrapper(UserDefinedFunctionWrapper):
"""
Wrapper for Python user-defined scalar function.
"""
def __init__(self, func, input_types, result_type, deterministic, name):
super(UserDefinedScalarFunctionWrapper, self).__init__(
func, input_types, deterministic, name)
if not isinstance(result_type, DataType):
raise TypeError(
"Invalid returnType: returnType should be DataType but is {}".format(result_type))
self._result_type = result_type
self._judf_placeholder = None
def java_user_defined_function(self, is_blink_planner, table_config):
if self._judf_placeholder is None:
self._judf_placeholder = self._create_judf(is_blink_planner, table_config)
return self._judf_placeholder
@@ -183,6 +223,69 @@ class UserDefinedFunctionWrapper(object):
return j_scalar_function
class UserDefinedTableFunctionWrapper(UserDefinedFunctionWrapper):
"""
Wrapper for Python user-defined table function.
"""
def __init__(self, func, input_types, result_types, deterministic=None, name=None):
super(UserDefinedTableFunctionWrapper, self).__init__(
func, input_types, deterministic, name)
if not isinstance(result_types, collections.Iterable):
result_types = [result_types]
for result_type in result_types:
if not isinstance(result_type, DataType):
raise TypeError(
"Invalid result_type: result_type should be DataType but contains {}".format(
result_type))
self._result_types = result_types
self._judtf_placeholder = None
def java_user_defined_function(self, is_blink_planner, table_config):
if self._judtf_placeholder is None:
self._judtf_placeholder = self._create_judtf(is_blink_planner, table_config)
return self._judtf_placeholder
def _create_judtf(self, is_blink_planner, table_config):
func = self._func
if not isinstance(self._func, UserDefinedFunction):
func = DelegationTableFunction(self._func)
import cloudpickle
serialized_func = cloudpickle.dumps(func)
gateway = get_gateway()
j_input_types = utils.to_jarray(gateway.jvm.TypeInformation,
[_to_java_type(i) for i in self._input_types])
j_result_types = utils.to_jarray(gateway.jvm.TypeInformation,
[_to_java_type(i) for i in self._result_types])
if is_blink_planner:
PythonTableUtils = gateway.jvm \
.org.apache.flink.table.planner.utils.python.PythonTableUtils
j_table_function = PythonTableUtils \
.createPythonTableFunction(table_config,
self._name,
bytearray(serialized_func),
j_input_types,
j_result_types,
self._deterministic,
_get_python_env())
else:
PythonTableUtils = gateway.jvm.PythonTableUtils
j_table_function = PythonTableUtils \
.createPythonTableFunction(self._name,
bytearray(serialized_func),
j_input_types,
j_result_types,
self._deterministic,
_get_python_env())
return j_table_function
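The Python function reaches the JVM as an opaque cloudpickle payload embedded in the generated Java TableFunction. A quick standalone sketch of the round trip _create_judtf relies on (MultiEmit mirrors the class used in the tests):

import cloudpickle

class MultiEmit(TableFunction):
    def eval(self, x, y):
        for i in range(y):
            yield x, i

payload = cloudpickle.dumps(MultiEmit())   # the bytes handed to createPythonTableFunction
restored = cloudpickle.loads(payload)      # what the Python worker does at runtime
print(list(restored.eval(3, 2)))           # [(3, 0), (3, 1)]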
# TODO: support configuring the Python execution environment
def _get_python_env():
gateway = get_gateway()
@@ -191,7 +294,11 @@ def _get_python_env():
def _create_udf(f, input_types, result_type, deterministic, name):
return UserDefinedFunctionWrapper(f, input_types, result_type, deterministic, name)
return UserDefinedScalarFunctionWrapper(f, input_types, result_type, deterministic, name)
def _create_udtf(f, input_types, result_types, deterministic, name):
return UserDefinedTableFunctionWrapper(f, input_types, result_types, deterministic, name)
def udf(f=None, input_types=None, result_type=None, deterministic=None, name=None):
@@ -225,8 +332,8 @@ def udf(f=None, input_types=None, result_type=None, deterministic=None, name=None):
this function is guaranteed to always return the same result given the
same parameters. (default True)
:type deterministic: bool
:return: UserDefinedFunctionWrapper or function.
:rtype: UserDefinedFunctionWrapper or function
:return: UserDefinedScalarFunctionWrapper or function.
:rtype: UserDefinedScalarFunctionWrapper or function
"""
# decorator
if f is None:
@@ -234,3 +341,44 @@ def udf(f=None, input_types=None, result_type=None, deterministic=None, name=None):
deterministic=deterministic, name=name)
else:
return _create_udf(f, input_types, result_type, deterministic, name)
def udtf(f=None, input_types=None, result_types=None, deterministic=None, name=None):
"""
Helper method for creating a user-defined table function.
Example:
::
>>> @udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
... result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
... def range_emit(s, e):
... for i in range(e):
... yield s, i
>>> class MultiEmit(TableFunction):
... def eval(self, i):
... return range(i)
>>> multi_emit = udtf(MultiEmit(), DataTypes.BIGINT(), DataTypes.BIGINT())
:param f: user-defined table function.
:type f: function or UserDefinedFunction or type
:param input_types: the input data types.
:type input_types: list[DataType] or DataType
:param result_types: the result data types.
:type result_types: list[DataType] or DataType
:param name: the function name.
:type name: str
:param deterministic: the determinism of the function's results. True if and only if a call to
this function is guaranteed to always return the same result given the
same parameters. (default True)
:type deterministic: bool
:return: UserDefinedTableFunctionWrapper or function.
:rtype: UserDefinedTableFunctionWrapper or function
"""
# decorator
if f is None:
return functools.partial(_create_udtf, input_types=input_types, result_types=result_types,
deterministic=deterministic, name=name)
else:
return _create_udtf(f, input_types, result_types, deterministic, name)
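As a usage sketch (assuming an already-configured TableEnvironment t_env with a registered "Results" sink, as in the tests above; range_pairs is a hypothetical name):

from pyflink.table import DataTypes
from pyflink.table.udf import udtf

@udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
      result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
def range_pairs(x, y):
    for i in range(y):
        yield x, i

# t_env.register_function("range_pairs", range_pairs)
# t_env.from_elements([(1, 2)], ['a', 'b']) \
#     .join_lateral("range_pairs(a, b) as (x, y)") \
#     .insert_into("Results")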
@@ -18,12 +18,14 @@
package org.apache.flink.table.planner.codegen
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.functions.python.{PythonEnv, PythonFunction}
import org.apache.flink.table.functions.{ScalarFunction, UserDefinedFunction}
import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
import org.apache.flink.table.planner.codegen.CodeGenUtils.{newName, primitiveDefaultValue, primitiveTypeTermForType}
import org.apache.flink.table.planner.codegen.Indenter.toISC
import org.apache.flink.table.runtime.generated.GeneratedFunction
import org.apache.flink.table.runtime.types.TypeInfoLogicalTypeConverter
import org.apache.flink.types.Row
/**
* A code generator for generating Python [[UserDefinedFunction]]s.
@@ -32,6 +34,8 @@ object PythonFunctionCodeGenerator {
private val PYTHON_SCALAR_FUNCTION_NAME = "PythonScalarFunction"
private val PYTHON_TABLE_FUNCTION_NAME = "PythonTableFunction"
/**
* Generates a [[ScalarFunction]] for the specified Python user-defined function.
*
@@ -123,4 +127,94 @@ object PythonFunctionCodeGenerator {
new GeneratedFunction(funcName, funcCode, ctx.references.toArray)
.newInstance(Thread.currentThread().getContextClassLoader)
}
/**
* Generates a [[TableFunction]] for the specified Python user-defined function.
*
* @param ctx The context of the code generator
* @param name name of the user-defined function
* @param serializedTableFunction serialized Python table function
* @param inputTypes input data types
* @param resultTypes expected result types
* @param deterministic the determinism of the function's results
* @param pythonEnv the Python execution environment
* @return instance of generated TableFunction
*/
def generateTableFunction(
ctx: CodeGeneratorContext,
name: String,
serializedTableFunction: Array[Byte],
inputTypes: Array[TypeInformation[_]],
resultTypes: Array[TypeInformation[_]],
deterministic: Boolean,
pythonEnv: PythonEnv): TableFunction[_] = {
val funcName = newName(PYTHON_TABLE_FUNCTION_NAME)
val inputParamCode = inputTypes.zipWithIndex.map { case (inputType, index) =>
s"${primitiveTypeTermForType(
TypeInfoLogicalTypeConverter.fromTypeInfoToLogicalType(inputType))} in$index"
}.mkString(", ")
val rowTypeTerm = classOf[Row].getCanonicalName
val typeInfoTypeTerm = classOf[TypeInformation[_]].getCanonicalName
val rowTypeInfoTerm = classOf[RowTypeInfo].getCanonicalName
val pythonEnvTypeTerm = classOf[PythonEnv].getCanonicalName
val serializedTableFunctionNameTerm =
ctx.addReusableObject(serializedTableFunction, "serializedTableFunction", "byte[]")
val pythonEnvNameTerm = ctx.addReusableObject(pythonEnv, "pythonEnv", pythonEnvTypeTerm)
val inputTypesCode = inputTypes
.map(ctx.addReusableObject(_, "inputType", typeInfoTypeTerm))
.mkString(", ")
val resultTypesCode = resultTypes
.map(ctx.addReusableObject(_, "resultType", typeInfoTypeTerm))
.mkString(", ")
val funcCode = j"""
|public class $funcName extends ${classOf[TableFunction[_]].getCanonicalName}<$rowTypeTerm>
| implements ${classOf[PythonFunction].getCanonicalName} {
|
| private static final long serialVersionUID = 1L;
|
| ${ctx.reuseMemberCode()}
|
| public $funcName(Object[] references) throws Exception {
| ${ctx.reuseInitCode()}
| }
|
| public void eval($inputParamCode) {
| }
|
| @Override
| public $typeInfoTypeTerm[] getParameterTypes(Class<?>[] signature) {
| return new $typeInfoTypeTerm[]{$inputTypesCode};
| }
|
| @Override
| public $typeInfoTypeTerm<$rowTypeTerm> getResultType() {
| return new $rowTypeInfoTerm(new $typeInfoTypeTerm[]{$resultTypesCode});
| }
|
| @Override
| public byte[] getSerializedPythonFunction() {
| return $serializedTableFunctionNameTerm;
| }
|
| @Override
| public $pythonEnvTypeTerm getPythonEnv() {
| return $pythonEnvNameTerm;
| }
|
| @Override
| public boolean isDeterministic() {
| return $deterministic;
| }
|
| @Override
| public String toString() {
| return "$name";
| }
|}
|""".stripMargin
new GeneratedFunction(funcName, funcCode, ctx.references.toArray)
.newInstance(Thread.currentThread().getContextClassLoader)
}
}
@@ -31,7 +31,7 @@ import org.apache.flink.api.java.io.CollectionInputFormat
import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
import org.apache.flink.core.io.InputSplit
import org.apache.flink.table.api.{TableConfig, TableSchema, Types}
import org.apache.flink.table.functions.ScalarFunction
import org.apache.flink.table.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.functions.python.PythonEnv
import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, PythonFunctionCodeGenerator}
import org.apache.flink.table.sources.InputFormatTableSource
@@ -69,6 +69,34 @@ object PythonTableUtils {
deterministic,
pythonEnv)
/**
* Creates a [[TableFunction]] for the specified Python TableFunction.
*
* @param config the table config used for code generation
* @param funcName class name of the user-defined function. Must be a valid Java class identifier
* @param serializedTableFunction serialized Python table function
* @param inputTypes input data types
* @param resultTypes expected result types
* @param deterministic the determinism of the function's results
* @param pythonEnv the Python execution environment
* @return A generated Java TableFunction representation for the specified Python TableFunction
*/
def createPythonTableFunction(
config: TableConfig,
funcName: String,
serializedTableFunction: Array[Byte],
inputTypes: Array[TypeInformation[_]],
resultTypes: Array[TypeInformation[_]],
deterministic: Boolean,
pythonEnv: PythonEnv): TableFunction[_] =
PythonFunctionCodeGenerator.generateTableFunction(
CodeGeneratorContext(config),
funcName,
serializedTableFunction,
inputTypes,
resultTypes,
deterministic,
pythonEnv)
/**
* Wrap the unpickled python data with an InputFormat. It will be passed to
* PythonInputFormatTableSource later.
......
@@ -18,11 +18,13 @@
package org.apache.flink.table.codegen
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.codegen.CodeGenUtils.{primitiveDefaultValue, primitiveTypeTermForTypeInfo, newName}
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.codegen.CodeGenUtils.{newName, primitiveDefaultValue, primitiveTypeTermForTypeInfo}
import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.functions.{UserDefinedFunction, ScalarFunction}
import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
import org.apache.flink.table.functions.python.{PythonEnv, PythonFunction}
import org.apache.flink.table.utils.EncodingUtils
import org.apache.flink.types.Row
/**
* A code generator for generating Python [[UserDefinedFunction]]s.
@@ -31,6 +33,8 @@ object PythonFunctionCodeGenerator extends Compiler[UserDefinedFunction] {
private val PYTHON_SCALAR_FUNCTION_NAME = "PythonScalarFunction"
private val PYTHON_TABLE_FUNCTION_NAME = "PythonTableFunction"
/**
* Generates a [[ScalarFunction]] for the specified Python user-defined function.
*
@@ -120,4 +124,97 @@ object PythonFunctionCodeGenerator extends Compiler[UserDefinedFunction] {
funcCode)
clazz.newInstance().asInstanceOf[ScalarFunction]
}
/**
* Generates a [[TableFunction]] for the specified Python user-defined function.
*
* @param name name of the user-defined function
* @param serializedTableFunction serialized Python table function
* @param inputTypes input data types
* @param resultTypes expected result types
* @param deterministic the determinism of the function's results
* @param pythonEnv the Python execution environment
* @return instance of generated TableFunction
*/
def generateTableFunction(
name: String,
serializedTableFunction: Array[Byte],
inputTypes: Array[TypeInformation[_]],
resultTypes: Array[TypeInformation[_]],
deterministic: Boolean,
pythonEnv: PythonEnv): TableFunction[_] = {
val funcName = newName(PYTHON_TABLE_FUNCTION_NAME)
val inputParamCode = inputTypes.zipWithIndex.map { case (inputType, index) =>
s"${primitiveTypeTermForTypeInfo(inputType)} in$index"
}.mkString(", ")
val encodingUtilsTypeTerm = classOf[EncodingUtils].getCanonicalName
val typeInfoTypeTerm = classOf[TypeInformation[_]].getCanonicalName
val rowTypeInfoTerm = classOf[RowTypeInfo].getCanonicalName
val rowTypeTerm = classOf[Row].getCanonicalName
val inputTypesCode = inputTypes.map(EncodingUtils.encodeObjectToString).map { inputType =>
s"""
|($typeInfoTypeTerm) $encodingUtilsTypeTerm.decodeStringToObject(
| "$inputType", $typeInfoTypeTerm.class)
|""".stripMargin
}.mkString(", ")
val resultTypesCode = resultTypes.map(EncodingUtils.encodeObjectToString).map { resultType =>
s"""
|($typeInfoTypeTerm) $encodingUtilsTypeTerm.decodeStringToObject(
| "$resultType", $typeInfoTypeTerm.class)
|""".stripMargin
}.mkString(", ")
val encodedTableFunction = EncodingUtils.encodeBytesToBase64(serializedTableFunction)
val encodedPythonEnv = EncodingUtils.encodeObjectToString(pythonEnv)
val pythonEnvTypeTerm = classOf[PythonEnv].getCanonicalName
val funcCode = j"""
|public class $funcName extends ${classOf[TableFunction[_]].getCanonicalName}<$rowTypeTerm>
| implements ${classOf[PythonFunction].getCanonicalName} {
|
| private static final long serialVersionUID = 1L;
|
| public void eval($inputParamCode) {
| }
|
| @Override
| public $typeInfoTypeTerm[] getParameterTypes(Class<?>[] signature) {
| return new $typeInfoTypeTerm[]{$inputTypesCode};
| }
|
| @Override
| public $typeInfoTypeTerm<$rowTypeTerm> getResultType() {
| return new $rowTypeInfoTerm(new $typeInfoTypeTerm[]{$resultTypesCode});
| }
|
| @Override
| public byte[] getSerializedPythonFunction() {
| return $encodingUtilsTypeTerm.decodeBase64ToBytes("$encodedTableFunction");
| }
|
| @Override
| public $pythonEnvTypeTerm getPythonEnv() {
| return ($pythonEnvTypeTerm) $encodingUtilsTypeTerm.decodeStringToObject(
| "$encodedPythonEnv", $pythonEnvTypeTerm.class);
| }
|
| @Override
| public boolean isDeterministic() {
| return $deterministic;
| }
|
| @Override
| public String toString() {
| return "$name";
| }
|}
|""".stripMargin
val clazz = compile(
Thread.currentThread().getContextClassLoader,
funcName,
funcCode)
clazz.newInstance().asInstanceOf[TableFunction[_]]
}
}
@@ -32,7 +32,7 @@ import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
import org.apache.flink.core.io.InputSplit
import org.apache.flink.table.api.{TableSchema, Types}
import org.apache.flink.table.codegen.PythonFunctionCodeGenerator
import org.apache.flink.table.functions.ScalarFunction
import org.apache.flink.table.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.functions.python.PythonEnv
import org.apache.flink.table.sources.InputFormatTableSource
import org.apache.flink.types.Row
@@ -67,6 +67,32 @@ object PythonTableUtils {
deterministic,
pythonEnv)
/**
* Creates a [[TableFunction]] for the specified Python TableFunction.
*
* @param funcName class name of the user-defined function. Must be a valid Java class identifier
* @param serializedTableFunction serialized Python table function
* @param inputTypes input data types
* @param resultTypes expected result types
* @param deterministic the determinism of the function's results
* @param pythonEnv the Python execution environment
* @return A generated Java TableFunction representation for the specified Python TableFunction
*/
def createPythonTableFunction(
funcName: String,
serializedTableFunction: Array[Byte],
inputTypes: Array[TypeInformation[_]],
resultTypes: Array[TypeInformation[_]],
deterministic: Boolean,
pythonEnv: PythonEnv): TableFunction[_] =
PythonFunctionCodeGenerator.generateTableFunction(
funcName,
serializedTableFunction,
inputTypes,
resultTypes,
deterministic,
pythonEnv)
/**
* Wrap the unpickled python data with an InputFormat. It will be passed to
* PythonInputFormatTableSource later.
......