Commit 9aa24d82 authored by hequn8128, committed by Dian Fu

[FLINK-15636][python] Support Python UDF in old planner under batch mode

This closes #10913.
Parent 11007c0d
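As context for the diff below: this change allows Python UDFs to be registered and executed on a `BatchTableEnvironment` backed by the old (Flink) planner. A minimal end-to-end sketch of the newly supported workflow, assuming the PyFlink 1.10-era API (`ExecutionEnvironment`, `BatchTableEnvironment.create`, `udf`, `register_function`); the table, column, and function names are illustrative and not taken from this commit:

```python
from pyflink.dataset import ExecutionEnvironment
from pyflink.table import BatchTableEnvironment, DataTypes
from pyflink.table.udf import udf

# Old-planner batch environment. Before this change, registering a Python UDF
# here failed with "Python UDF is not supported in old planner under batch mode!".
exec_env = ExecutionEnvironment.get_execution_environment()
t_env = BatchTableEnvironment.create(exec_env)

# Wrap a plain Python callable as a scalar UDF and register it by name.
add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())
t_env.register_function("add_one", add_one)

# Use the UDF in a batch query on the old planner.
t = t_env.from_elements([(1, 'hi'), (2, 'hello')], ['a', 'b'])
result = t.select("add_one(a), b")
```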
......@@ -144,8 +144,6 @@ $ python --version
$ python -m pip install apache-beam==2.15.0
{% endhighlight %}
<span class="label label-info">Note</span> Currently, Python UDFs are supported in the Blink planner under both streaming and batch mode, while the old planner only supports them under streaming mode.
Both Java/Scala scalar functions and Python scalar functions can be used in the Python Table API and SQL. To define a Python scalar function, one can extend the base class `ScalarFunction` in `pyflink.table.udf` and implement an evaluation method. The behavior of a Python scalar function is determined by this evaluation method, which must be named `eval`. The evaluation method can also accept variable arguments, such as `eval(*args)`.
The following example shows how to define your own Java and Python hash code functions, register them in the TableEnvironment, and call them in a query. Note that you can configure your scalar function via a constructor before it is registered:
......
......@@ -144,8 +144,6 @@ $ python --version
$ python -m pip install apache-beam==2.15.0
{% endhighlight %}
<span class="label label-info">Note</span> Currently, Python UDFs are supported in the Blink planner under both streaming and batch mode, while the old planner only supports them under streaming mode.
Both Java/Scala scalar functions and Python scalar functions can be used in the Python Table API and SQL. To define a Python scalar function, one can extend the base class `ScalarFunction` in `pyflink.table.udf` and implement an evaluation method. The behavior of a Python scalar function is determined by this evaluation method, which must be named `eval`. The evaluation method can also accept variable arguments, such as `eval(*args)`.
The following example shows how to define your own Java and Python hash code functions, register them in the TableEnvironment, and call them in a query. Note that you can configure your scalar function via a constructor before it is registered:
......
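To make the documentation note above concrete, here is a minimal sketch of a Python scalar function defined by subclassing `ScalarFunction` and implementing `eval`; the hash-code example the docs refer to is elided in this hunk, so the class below (`HashCode`, with a constructor-configurable `factor`) is illustrative only:

```python
from pyflink.table import DataTypes
from pyflink.table.udf import ScalarFunction, udf

class HashCode(ScalarFunction):
    """Illustrative Python scalar function; configured via its constructor before registration."""

    def __init__(self, factor=12):
        self.factor = factor

    def eval(self, s):
        # The behavior of the UDF is defined entirely by this eval method.
        return hash(s) * self.factor

# Wrap the instance with udf(...) so it can be registered in a TableEnvironment,
# e.g. t_env.register_function("hash_code", hash_code).
hash_code = udf(HashCode(), DataTypes.BIGINT(), DataTypes.BIGINT())
```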
......@@ -743,8 +743,6 @@ class TableEnvironment(object):
:param function: The python user-defined function to register.
:type function: pyflink.table.udf.UserDefinedFunctionWrapper
"""
if not self._is_blink_planner and isinstance(self, BatchTableEnvironment):
raise Exception("Python UDF is not supported in old planner under batch mode!")
self._j_tenv.registerFunction(name, function._judf(self._is_blink_planner,
self.get_config()._j_table_config))
......
......@@ -27,7 +27,8 @@ from pyflink.table.udf import udf
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import (PyFlinkBlinkStreamTableTestCase,
PyFlinkBlinkBatchTableTestCase,
PyFlinkStreamTableTestCase)
PyFlinkStreamTableTestCase,
PyFlinkBatchTableTestCase)
class DependencyTests(object):
......@@ -62,6 +63,30 @@ class FlinkStreamDependencyTests(DependencyTests, PyFlinkStreamTableTestCase):
pass
class FlinkBatchDependencyTests(PyFlinkBatchTableTestCase):

    def test_add_python_file(self):
        python_file_dir = os.path.join(self.tempdir, "python_file_dir_" + str(uuid.uuid4()))
        os.mkdir(python_file_dir)
        python_file_path = os.path.join(python_file_dir, "test_dependency_manage_lib.py")
        with open(python_file_path, 'w') as f:
            f.write("def add_two(a):\n    return a + 2")
        self.t_env.add_python_file(python_file_path)

        def plus_two(i):
            from test_dependency_manage_lib import add_two
            return add_two(i)

        self.t_env.register_function("add_two", udf(plus_two, DataTypes.BIGINT(),
                                                    DataTypes.BIGINT()))

        t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])\
            .select("add_two(a), a")
        result = self.collect(t)
        self.assertEqual(result, ["3,1", "4,2", "5,3"])
class BlinkBatchDependencyTests(DependencyTests, PyFlinkBlinkBatchTableTestCase):
pass
......
......@@ -481,13 +481,18 @@ class PyFlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests,
class PyFlinkBatchUserDefinedFunctionTests(PyFlinkBatchTableTestCase):

    def test_invalid_register_udf(self):
        self.assertRaises(
            Exception,
            lambda: self.t_env.register_function(
                "add_one",
                udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
        )

    def test_chaining_scalar_function(self):
        self.t_env.register_function(
            "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
        self.t_env.register_function(
            "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
        self.t_env.register_function("add", add)

        t = self.t_env.from_elements([(1, 2, 1), (2, 5, 2), (3, 1, 3)], ['a', 'b', 'c'])\
            .select("add(add_one(a), subtract_one(b)), c, 1")
        result = self.collect(t)
        self.assertEqual(result, ["3,1,1", "7,2,1", "4,3,1"])
class PyFlinkBlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests,
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.functions.python;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ConfigurationUtils;
import org.apache.flink.python.PythonConfig;
import org.apache.flink.python.PythonFunctionRunner;
import org.apache.flink.python.PythonOptions;
import org.apache.flink.python.env.ProcessPythonEnvironmentManager;
import org.apache.flink.python.env.PythonDependencyInfo;
import org.apache.flink.python.env.PythonEnvironmentManager;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.functions.python.PythonEnv;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.runtime.runners.python.PythonScalarFunctionRunner;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.utils.LegacyTypeInfoDataTypeConverter;
import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
/**
* The {@link RichFlatMapFunction} used to invoke Python {@link ScalarFunction} functions for the
* old planner.
*/
@Internal
public final class PythonScalarFunctionFlatMap
extends RichFlatMapFunction<Row, Row> implements ResultTypeQueryable<Row> {
private static final long serialVersionUID = 1L;
private static final Logger LOG = LoggerFactory.getLogger(PythonScalarFunctionFlatMap.class);
/**
* The type serializer for the forwarded fields.
*/
private transient TypeSerializer<Row> forwardedInputSerializer;
/**
* The Python {@link ScalarFunction}s to be executed.
*/
private final PythonFunctionInfo[] scalarFunctions;
/**
* The input logical type.
*/
private final RowType inputType;
/**
* The output logical type.
*/
private final RowType outputType;
/**
* The offsets of udf inputs.
*/
private final int[] udfInputOffsets;
/**
* The offsets of the fields which should be forwarded.
*/
private final int[] forwardedFields;
/**
* The udf input logical type.
*/
private transient RowType udfInputType;
/**
* The udf output logical type.
*/
private transient RowType udfOutputType;
/**
* The queue holding the input elements for which the execution results have not been received.
*/
private transient LinkedBlockingQueue<Row> forwardedInputQueue;
/**
* The queue holding the user-defined function execution results. The execution results are in
* the same order as the input elements.
*/
private transient LinkedBlockingQueue<Row> udfResultQueue;
/**
* The python config.
*/
private final PythonConfig config;
/**
* Use an AtomicBoolean because we start/stop bundles by a timer thread.
*/
private transient AtomicBoolean bundleStarted;
/**
* Max number of elements to include in a bundle.
*/
private transient int maxBundleSize;
/**
* The collector used to collect records.
*/
private transient Collector<Row> resultCollector;
/**
* Number of processed elements in the current bundle.
*/
private transient int elementCount;
/**
* The {@link PythonFunctionRunner} which is responsible for Python user-defined function execution.
*/
private transient PythonFunctionRunner<Row> pythonFunctionRunner;
public PythonScalarFunctionFlatMap(
Configuration config,
PythonFunctionInfo[] scalarFunctions,
RowType inputType,
RowType outputType,
int[] udfInputOffsets,
int[] forwardedFields) {
this.scalarFunctions = Preconditions.checkNotNull(scalarFunctions);
this.inputType = Preconditions.checkNotNull(inputType);
this.outputType = Preconditions.checkNotNull(outputType);
this.udfInputOffsets = Preconditions.checkNotNull(udfInputOffsets);
this.forwardedFields = Preconditions.checkNotNull(forwardedFields);
this.config = new PythonConfig(Preconditions.checkNotNull(config));
}
@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);
this.elementCount = 0;
this.bundleStarted = new AtomicBoolean(false);
this.maxBundleSize = config.getMaxBundleSize();
if (this.maxBundleSize <= 0) {
this.maxBundleSize = PythonOptions.MAX_BUNDLE_SIZE.defaultValue();
LOG.error("Invalid value for the maximum bundle size. Using default value of " +
this.maxBundleSize + '.');
} else {
LOG.info("The maximum bundle size is configured to {}.", this.maxBundleSize);
}
if (config.getMaxBundleTimeMills() != PythonOptions.MAX_BUNDLE_TIME_MILLS.defaultValue()) {
LOG.info("Maximum bundle time takes no effect in old planner under batch mode. " +
"Config maximum bundle size instead! " +
"Under batch mode, bundle size should be enough to control both throughput and latency.");
}
forwardedInputQueue = new LinkedBlockingQueue<>();
udfResultQueue = new LinkedBlockingQueue<>();
udfInputType = new RowType(
Arrays.stream(udfInputOffsets)
.mapToObj(i -> inputType.getFields().get(i))
.collect(Collectors.toList()));
udfOutputType = new RowType(outputType.getFields().subList(forwardedFields.length, outputType.getFieldCount()));
RowTypeInfo forwardedInputTypeInfo = new RowTypeInfo(
Arrays.stream(forwardedFields)
.mapToObj(i -> inputType.getFields().get(i))
.map(RowType.RowField::getType)
.map(TypeConversions::fromLogicalToDataType)
.map(TypeConversions::fromDataTypeToLegacyInfo)
.toArray(TypeInformation[]::new));
forwardedInputSerializer = forwardedInputTypeInfo.createSerializer(getRuntimeContext().getExecutionConfig());
this.pythonFunctionRunner = createPythonFunctionRunner();
this.pythonFunctionRunner.open();
}
@Override
public void flatMap(Row value, Collector<Row> out) throws Exception {
this.resultCollector = out;
bufferInput(value);
checkInvokeStartBundle();
pythonFunctionRunner.processElement(getUdfInput(value));
checkInvokeFinishBundleByCount();
emitResults();
}
/**
* Checks whether to invoke startBundle.
*/
private void checkInvokeStartBundle() throws Exception {
if (bundleStarted.compareAndSet(false, true)) {
pythonFunctionRunner.startBundle();
}
}
/**
* Checks whether to invoke finishBundle by elements count. Called in flatMap.
*/
private void checkInvokeFinishBundleByCount() throws Exception {
elementCount++;
if (elementCount >= maxBundleSize) {
invokeFinishBundle();
}
}
private void invokeFinishBundle() throws Exception {
if (bundleStarted.compareAndSet(true, false)) {
pythonFunctionRunner.finishBundle();
emitResults();
elementCount = 0;
}
}
private Row getUdfInput(Row element) {
return Row.project(element, udfInputOffsets);
}
private PythonEnv getPythonEnv() {
return scalarFunctions[0].getPythonFunction().getPythonEnv();
}
private PythonFunctionRunner<Row> createPythonFunctionRunner() throws IOException {
FnDataReceiver<Row> udfResultReceiver = input -> {
// handover to queue, do not block the result receiver thread
udfResultQueue.put(input);
};
return new PythonScalarFunctionRunner(
getRuntimeContext().getTaskName(),
udfResultReceiver,
scalarFunctions,
createPythonEnvironmentManager(),
udfInputType,
udfOutputType);
}
private PythonEnvironmentManager createPythonEnvironmentManager() throws IOException {
PythonDependencyInfo dependencyInfo = PythonDependencyInfo.create(
config, getRuntimeContext().getDistributedCache());
PythonEnv pythonEnv = getPythonEnv();
if (pythonEnv.getExecType() == PythonEnv.ExecType.PROCESS) {
return new ProcessPythonEnvironmentManager(
dependencyInfo,
ConfigurationUtils.splitPaths(System.getProperty("java.io.tmpdir")),
System.getenv());
} else {
throw new UnsupportedOperationException(String.format(
"Execution type '%s' is not supported.", pythonEnv.getExecType()));
}
}
private void bufferInput(Row input) {
Row forwardedFieldsRow = Row.project(input, forwardedFields);
if (getRuntimeContext().getExecutionConfig().isObjectReuseEnabled()) {
forwardedFieldsRow = forwardedInputSerializer.copy(forwardedFieldsRow);
}
forwardedInputQueue.add(forwardedFieldsRow);
}
private void emitResults() {
Row udfResult;
while ((udfResult = udfResultQueue.poll()) != null) {
Row input = forwardedInputQueue.poll();
this.resultCollector.collect(Row.join(input, udfResult));
}
}
@Override
public TypeInformation<Row> getProducedType() {
return (TypeInformation<Row>) LegacyTypeInfoDataTypeConverter
.toLegacyTypeInfo(LogicalTypeDataTypeConverter.toDataType(outputType));
}
@Override
public void close() throws Exception {
try {
invokeFinishBundle();
if (pythonFunctionRunner != null) {
pythonFunctionRunner.close();
pythonFunctionRunner = null;
}
} finally {
super.close();
}
}
}
......@@ -54,7 +54,8 @@ class BatchOptimizer(
val decorPlan = RelDecorrelator.decorrelateQuery(expandedPlan)
val normalizedPlan = optimizeNormalizeLogicalPlan(decorPlan)
val logicalPlan = optimizeLogicalPlan(normalizedPlan)
optimizePhysicalPlan(logicalPlan, FlinkConventions.DATASET)
val logicalRewritePlan = optimizeLogicalRewritePlan(logicalPlan)
optimizePhysicalPlan(logicalRewritePlan, FlinkConventions.DATASET)
}
/**
......
......@@ -17,7 +17,7 @@
*/
package org.apache.flink.table.plan.nodes
import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexNode}
import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexNode, RexProgram}
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.flink.table.api.TableException
import org.apache.flink.table.functions.python.{PythonFunction, PythonFunctionInfo, SimplePythonFunction}
......@@ -89,4 +89,18 @@ trait CommonPythonCalc {
new PythonFunctionInfo(pythonFunction, inputs.toArray)
}
}
private[flink] def getPythonRexCalls(calcProgram: RexProgram): Array[RexCall] = {
calcProgram.getProjectList
.map(calcProgram.expandLocalRef)
.collect { case call: RexCall => call }
.toArray
}
private[flink] def getForwardedFields(calcProgram: RexProgram): Array[Int] = {
calcProgram.getProjectList
.map(calcProgram.expandLocalRef)
.collect { case inputRef: RexInputRef => inputRef.getIndex }
.toArray
}
}
......@@ -18,11 +18,10 @@
package org.apache.flink.table.plan.nodes.dataset
import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.Calc
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rex._
import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.api.java.DataSet
......@@ -31,7 +30,6 @@ import org.apache.flink.table.api.BatchQueryConfig
import org.apache.flink.table.api.internal.BatchTableEnvImpl
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.FunctionCodeGenerator
import org.apache.flink.table.plan.nodes.CommonCalc
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.FlatMapRunner
import org.apache.flink.types.Row
......@@ -40,7 +38,6 @@ import scala.collection.JavaConverters._
/**
* Flink RelNode which matches along with LogicalCalc.
*
*/
class DataSetCalc(
cluster: RelOptCluster,
......@@ -49,40 +46,18 @@ class DataSetCalc(
rowRelDataType: RelDataType,
calcProgram: RexProgram,
ruleDescription: String)
extends Calc(cluster, traitSet, input, calcProgram)
with CommonCalc
with DataSetRel {
override def deriveRowType(): RelDataType = rowRelDataType
extends DataSetCalcBase(
cluster,
traitSet,
input,
rowRelDataType,
calcProgram,
ruleDescription) {
override def copy(traitSet: RelTraitSet, child: RelNode, program: RexProgram): Calc = {
new DataSetCalc(cluster, traitSet, child, getRowType, program, ruleDescription)
}
override def toString: String = calcToString(calcProgram, getExpressionString)
override def explainTerms(pw: RelWriter): RelWriter = {
pw.input("input", getInput)
.item("select", selectionToString(calcProgram, getExpressionString))
.itemIf("where",
conditionToString(calcProgram, getExpressionString),
calcProgram.getCondition != null)
}
override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val child = this.getInput
val rowCnt = metadata.getRowCount(child)
computeSelfCost(calcProgram, planner, rowCnt)
}
override def estimateRowCount(metadata: RelMetadataQuery): Double = {
val child = this.getInput
val rowCnt = metadata.getRowCount(child)
estimateRowCount(calcProgram, rowCnt)
}
override def translateToPlan(
tableEnv: BatchTableEnvImpl,
queryConfig: BatchQueryConfig): DataSet[Row] = {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.nodes.dataset
import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.Calc
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rex.RexProgram
import org.apache.flink.table.plan.nodes.CommonCalc
/**
* Base RelNode for data set calc.
*/
abstract class DataSetCalcBase(
cluster: RelOptCluster,
traitSet: RelTraitSet,
input: RelNode,
rowRelDataType: RelDataType,
calcProgram: RexProgram,
ruleDescription: String)
extends Calc(cluster, traitSet, input, calcProgram)
with CommonCalc
with DataSetRel {
override def deriveRowType(): RelDataType = rowRelDataType
override def toString: String = calcToString(calcProgram, getExpressionString)
override def explainTerms(pw: RelWriter): RelWriter = {
pw.input("input", getInput)
.item("select", selectionToString(calcProgram, getExpressionString))
.itemIf("where",
conditionToString(calcProgram, getExpressionString),
calcProgram.getCondition != null)
}
override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val child = this.getInput
val rowCnt = metadata.getRowCount(child)
computeSelfCost(calcProgram, planner, rowCnt)
}
override def estimateRowCount(metadata: RelMetadataQuery): Double = {
val child = this.getInput
val rowCnt = metadata.getRowCount(child)
estimateRowCount(calcProgram, rowCnt)
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.nodes.dataset
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.Calc
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rex._
import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.configuration.Configuration
import org.apache.flink.table.api.BatchQueryConfig
import org.apache.flink.table.api.internal.BatchTableEnvImpl
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.functions.python.PythonFunctionInfo
import org.apache.flink.table.plan.nodes.CommonPythonCalc
import org.apache.flink.table.plan.nodes.dataset.DataSetPythonCalc.PYTHON_SCALAR_FUNCTION_FLAT_MAP_NAME
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.types.logical.RowType
import org.apache.flink.table.types.utils.TypeConversions
import org.apache.flink.types.Row
import scala.collection.JavaConversions._
/**
* Flink RelNode for Python ScalarFunctions.
*/
class DataSetPythonCalc(
cluster: RelOptCluster,
traitSet: RelTraitSet,
input: RelNode,
rowRelDataType: RelDataType,
calcProgram: RexProgram,
ruleDescription: String)
extends DataSetCalcBase(
cluster,
traitSet,
input,
rowRelDataType,
calcProgram,
ruleDescription)
with CommonPythonCalc {
private lazy val inputSchema = new RowSchema(input.getRowType)
override def copy(traitSet: RelTraitSet, child: RelNode, program: RexProgram): Calc = {
new DataSetPythonCalc(cluster, traitSet, child, getRowType, program, ruleDescription)
}
override def translateToPlan(
tableEnv: BatchTableEnvImpl,
queryConfig: BatchQueryConfig): DataSet[Row] = {
val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv, queryConfig)
val flatMapFunctionResultTypeInfo = new RowTypeInfo(
getForwardedFields(calcProgram).map(inputSchema.fieldTypeInfos.get(_)) ++
getPythonRexCalls(calcProgram).map(node => FlinkTypeFactory.toTypeInfo(node.getType)): _*)
// construct the Python ScalarFunction flatMap function
val flatMapFunctionInputRowType = TypeConversions.fromLegacyInfoToDataType(
inputSchema.typeInfo).getLogicalType.asInstanceOf[RowType]
val flatMapFunctionOutputRowType = TypeConversions.fromLegacyInfoToDataType(
flatMapFunctionResultTypeInfo).getLogicalType.asInstanceOf[RowType]
val flatMapFunction = getPythonScalarFunctionFlatMap(
tableEnv.getConfig.getConfiguration,
flatMapFunctionInputRowType,
flatMapFunctionOutputRowType,
calcProgram)
inputDS.flatMap(flatMapFunction).name(calcOpName(calcProgram, getExpressionString))
}
private[flink] def getPythonScalarFunctionFlatMap(
config: Configuration,
inputRowType: RowType,
outputRowType: RowType,
calcProgram: RexProgram) = {
val clazz = loadClass(PYTHON_SCALAR_FUNCTION_FLAT_MAP_NAME)
val ctor = clazz.getConstructor(
classOf[Configuration],
classOf[Array[PythonFunctionInfo]],
classOf[RowType],
classOf[RowType],
classOf[Array[Int]],
classOf[Array[Int]])
val (udfInputOffsets, pythonFunctionInfos) =
extractPythonScalarFunctionInfos(getPythonRexCalls(calcProgram))
ctor.newInstance(
config,
pythonFunctionInfos,
inputRowType,
outputRowType,
udfInputOffsets,
getForwardedFields(calcProgram))
.asInstanceOf[RichFlatMapFunction[Row, Row]]
}
}
object DataSetPythonCalc {
val PYTHON_SCALAR_FUNCTION_FLAT_MAP_NAME =
"org.apache.flink.table.runtime.functions.python.PythonScalarFunctionFlatMap"
}
......@@ -21,7 +21,7 @@ package org.apache.flink.table.plan.nodes.datastream
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.Calc
import org.apache.calcite.rex.{RexCall, RexInputRef, RexProgram}
import org.apache.calcite.rex.RexProgram
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.datastream.DataStream
......@@ -71,30 +71,16 @@ class DataStreamPythonCalc(
ruleDescription)
}
private lazy val pythonRexCalls = calcProgram.getProjectList
.map(calcProgram.expandLocalRef)
.collect { case call: RexCall => call }
.toArray
private lazy val forwardedFields: Array[Int] = calcProgram.getProjectList
.map(calcProgram.expandLocalRef)
.collect { case inputRef: RexInputRef => inputRef.getIndex }
.toArray
private lazy val (pythonUdfInputOffsets, pythonFunctionInfos) =
extractPythonScalarFunctionInfos(pythonRexCalls)
override def translateToPlan(
planner: StreamPlanner,
queryConfig: StreamQueryConfig): DataStream[CRow] = {
val inputDataStream =
getInput.asInstanceOf[DataStreamRel].translateToPlan(planner, queryConfig)
val inputParallelism = inputDataStream.getParallelism
val pythonOperatorResultTypeInfo = new RowTypeInfo(
forwardedFields.map(inputSchema.fieldTypeInfos.get(_)) ++
pythonRexCalls.map(node => FlinkTypeFactory.toTypeInfo(node.getType)): _*)
getForwardedFields(calcProgram).map(inputSchema.fieldTypeInfos.get(_)) ++
getPythonRexCalls(calcProgram).map(node => FlinkTypeFactory.toTypeInfo(node.getType)): _*)
// construct the Python operator
val pythonOperatorInputRowType = TypeConversions.fromLegacyInfoToDataType(
......@@ -105,7 +91,7 @@ class DataStreamPythonCalc(
planner.getConfig.getConfiguration,
pythonOperatorInputRowType,
pythonOperatorOutputRowType,
pythonUdfInputOffsets)
calcProgram)
inputDataStream
.transform(
......@@ -120,7 +106,7 @@ class DataStreamPythonCalc(
config: Configuration,
inputRowType: RowType,
outputRowType: RowType,
udfInputOffsets: Array[Int]) = {
calcProgram: RexProgram) = {
val clazz = loadClass(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME)
val ctor = clazz.getConstructor(
classOf[Configuration],
......@@ -129,13 +115,15 @@ class DataStreamPythonCalc(
classOf[RowType],
classOf[Array[Int]],
classOf[Array[Int]])
val (udfInputOffsets, pythonFunctionInfos) =
extractPythonScalarFunctionInfos(getPythonRexCalls(calcProgram))
ctor.newInstance(
config,
pythonFunctionInfos,
inputRowType,
outputRowType,
udfInputOffsets,
forwardedFields)
getForwardedFields(calcProgram))
.asInstanceOf[OneInputStreamOperator[CRow, CRow]]
}
}
......
......@@ -196,6 +196,7 @@ object FlinkRuleSets {
DataSetAggregateRule.INSTANCE,
DataSetDistinctRule.INSTANCE,
DataSetCalcRule.INSTANCE,
DataSetPythonCalcRule.INSTANCE,
DataSetJoinRule.INSTANCE,
DataSetSingleRowJoinRule.INSTANCE,
DataSetScanRule.INSTANCE,
......
......@@ -18,12 +18,15 @@
package org.apache.flink.table.plan.rules.dataSet
import org.apache.calcite.plan.{RelOptRule, RelTraitSet}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.dataset.DataSetCalc
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCalc
import org.apache.flink.table.plan.util.PythonUtil.containsPythonCall
import scala.collection.JavaConverters._
class DataSetCalcRule
extends ConverterRule(
......@@ -32,21 +35,27 @@ class DataSetCalcRule
FlinkConventions.DATASET,
"DataSetCalcRule") {
def convert(rel: RelNode): RelNode = {
val calc: FlinkLogicalCalc = rel.asInstanceOf[FlinkLogicalCalc]
val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASET)
val convInput: RelNode = RelOptRule.convert(calc.getInput, FlinkConventions.DATASET)
new DataSetCalc(
rel.getCluster,
traitSet,
convInput,
rel.getRowType,
calc.getProgram,
"DataSetCalcRule")
}
override def matches(call: RelOptRuleCall): Boolean = {
val calc: FlinkLogicalCalc = call.rel(0).asInstanceOf[FlinkLogicalCalc]
val program = calc.getProgram
!program.getExprList.asScala.exists(containsPythonCall)
}
def convert(rel: RelNode): RelNode = {
val calc: FlinkLogicalCalc = rel.asInstanceOf[FlinkLogicalCalc]
val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASET)
val convInput: RelNode = RelOptRule.convert(calc.getInput, FlinkConventions.DATASET)
new DataSetCalc(
rel.getCluster,
traitSet,
convInput,
rel.getRowType,
calc.getProgram,
"DataSetCalcRule")
}
}
object DataSetCalcRule {
val INSTANCE: RelOptRule = new DataSetCalcRule
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.rules.dataSet
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.dataset.DataSetPythonCalc
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCalc
import org.apache.flink.table.plan.util.PythonUtil.containsPythonCall
import scala.collection.JavaConverters._
class DataSetPythonCalcRule
extends ConverterRule(
classOf[FlinkLogicalCalc],
FlinkConventions.LOGICAL,
FlinkConventions.DATASET,
"DataSetPythonCalcRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val calc: FlinkLogicalCalc = call.rel(0).asInstanceOf[FlinkLogicalCalc]
val program = calc.getProgram
program.getExprList.asScala.exists(containsPythonCall)
}
def convert(rel: RelNode): RelNode = {
val calc: FlinkLogicalCalc = rel.asInstanceOf[FlinkLogicalCalc]
val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASET)
val convInput: RelNode = RelOptRule.convert(calc.getInput, FlinkConventions.DATASET)
new DataSetPythonCalc(
rel.getCluster,
traitSet,
convInput,
rel.getRowType,
calc.getProgram,
"DataSetPythonCalcRule")
}
}
object DataSetPythonCalcRule {
val INSTANCE: RelOptRule = new DataSetPythonCalcRule
}