Commit 2848d4d4 authored by huangxingbo, committed by Dian Fu

[FLINK-20620][table-planner-blink][python] Port BatchExecPythonCalc and StreamExecPythonCalc to Java

This closes #14496.
Parent 7ea384b9
@@ -16,51 +16,26 @@
  * limitations under the License.
  */
-package org.apache.flink.table.planner.plan.nodes.exec.batch
+package org.apache.flink.table.planner.plan.nodes.exec.batch;
 
-import org.apache.flink.api.dag.Transformation
-import org.apache.flink.core.memory.ManagedMemoryUseCase
-import org.apache.flink.table.data.RowData
-import org.apache.flink.table.planner.delegation.PlannerBase
-import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc
-import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNodeBase}
-import org.apache.flink.table.types.logical.RowType
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
+import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc;
+import org.apache.flink.table.types.logical.LogicalType;
 
-import org.apache.calcite.rex.RexProgram
-import java.util
+import org.apache.calcite.rex.RexProgram;
 
 /**
- * Batch ExecNode for Python ScalarFunctions.
- *
- * <p>Note: This class can't be ported to Java,
- * because java class can't extend scala interface with default implementation.
- * FLINK-20620 will port this class to Java.
+ * Batch {@link ExecNode} for Python ScalarFunctions.
  */
-class BatchExecPythonCalc(
-    calcProgram: RexProgram,
-    inputEdge: ExecEdge,
-    outputType: RowType,
-    description: String)
-  extends ExecNodeBase[RowData](
-    util.Collections.singletonList(inputEdge),
-    outputType,
-    description)
-  with BatchExecNode[RowData]
-  with CommonExecPythonCalc {
-
-  override protected def translateToPlanInternal(planner: PlannerBase): Transformation[RowData] = {
-    val inputTransform = getInputNodes.get(0).translateToPlan(planner)
-      .asInstanceOf[Transformation[RowData]]
-    val ret = createPythonOneInputTransformation(
-      inputTransform,
-      calcProgram,
-      "BatchExecPythonCalc",
-      getConfig(planner.getExecEnv, planner.getTableConfig))
-
-    if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
-      ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON)
-    }
-    ret
-  }
+public class BatchExecPythonCalc extends CommonExecPythonCalc implements BatchExecNode<RowData> {
+
+    public BatchExecPythonCalc(
+            RexProgram calcProgram,
+            ExecEdge inputEdge,
+            LogicalType outputType,
+            String description) {
+        super(calcProgram, inputEdge, outputType, description);
+    }
 }
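An aside on the scaladoc note removed above: a method body defined inside a Scala trait is not reliably exposed to Java as an interface default method (how it compiles depends on the Scala version and on what else the trait contains), so a Java class implementing the trait's interface is left with only abstract signatures. A minimal sketch of the failure mode, using hypothetical names:

    // The interface a Java class effectively sees for a Scala trait such as
    //   trait CommonCalc { def name(): String = "calc" }
    // -- the method body is not carried over as a Java default method:
    interface CommonCalc {
        String name();
    }

    // A Java subclass therefore has to duplicate every body by hand, which is
    // why these nodes stayed in Scala until CommonExecPythonCalc itself was
    // ported to a Java base class in this commit.
    class JavaBatchCalc implements CommonCalc {
        @Override
        public String name() {
            return "calc";
        }
    }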
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.plan.nodes.exec.common;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.nodes.exec.utils.CommonPythonUtil;
import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Base class for the Python calc {@link ExecNode}s.
*/
public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData> {
private static final String PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
"org.apache.flink.table.runtime.operators.python.scalar." +
"RowDataPythonScalarFunctionOperator";
private static final String ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
"org.apache.flink.table.runtime.operators.python.scalar.arrow." +
"RowDataArrowPythonScalarFunctionOperator";
private final RexProgram calcProgram;
public CommonExecPythonCalc(
RexProgram calcProgram,
ExecEdge inputEdge,
LogicalType outputType,
String description) {
super(Collections.singletonList(inputEdge), outputType, description);
this.calcProgram = calcProgram;
}
@SuppressWarnings("unchecked")
@Override
protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
final ExecNode<RowData> inputNode = (ExecNode<RowData>) getInputNodes().get(0);
final Transformation<RowData> inputTransform = inputNode.translateToPlan(planner);
OneInputTransformation<RowData, RowData> ret = createPythonOneInputTransformation(
inputTransform,
calcProgram,
getDesc(),
CommonPythonUtil.getConfig(planner.getExecEnv(), planner.getTableConfig()));
if (inputsContainSingleton()) {
ret.setParallelism(1);
ret.setMaxParallelism(1);
}
if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(planner.getTableConfig().getConfiguration())) {
ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
return ret;
}
private OneInputTransformation<RowData, RowData> createPythonOneInputTransformation(
Transformation<RowData> inputTransform,
RexProgram calcProgram,
String name,
Configuration config) {
List<RexCall> pythonRexCalls = calcProgram.getProjectList()
.stream()
.map(calcProgram::expandLocalRef)
.filter(x -> x instanceof RexCall)
.map(x -> (RexCall) x)
.collect(Collectors.toList());
List<Integer> forwardedFields = calcProgram.getProjectList()
.stream()
.map(calcProgram::expandLocalRef)
.filter(x -> x instanceof RexInputRef)
.map(x -> ((RexInputRef) x).getIndex())
.collect(Collectors.toList());
Tuple2<int[], PythonFunctionInfo[]> extractResult = extractPythonScalarFunctionInfos(pythonRexCalls);
int[] pythonUdfInputOffsets = extractResult.f0;
PythonFunctionInfo[] pythonFunctionInfos = extractResult.f1;
LogicalType[] inputLogicalTypes =
((InternalTypeInfo<RowData>) inputTransform.getOutputType()).toRowFieldTypes();
InternalTypeInfo<RowData> pythonOperatorInputTypeInfo = (InternalTypeInfo<RowData>) inputTransform.getOutputType();
List<LogicalType> forwardedFieldsLogicalTypes = forwardedFields.stream()
.map(i -> inputLogicalTypes[i])
.collect(Collectors.toList());
List<LogicalType> pythonCallLogicalTypes = pythonRexCalls.stream()
.map(node -> FlinkTypeFactory.toLogicalType(node.getType()))
.collect(Collectors.toList());
List<LogicalType> fieldsLogicalTypes = new ArrayList<>();
fieldsLogicalTypes.addAll(forwardedFieldsLogicalTypes);
fieldsLogicalTypes.addAll(pythonCallLogicalTypes);
InternalTypeInfo<RowData> pythonOperatorResultTypeInfo = InternalTypeInfo.ofFields(
fieldsLogicalTypes.toArray(new LogicalType[0]));
OneInputStreamOperator<RowData, RowData> pythonOperator = getPythonScalarFunctionOperator(
config,
pythonOperatorInputTypeInfo,
pythonOperatorResultTypeInfo,
pythonUdfInputOffsets,
pythonFunctionInfos,
forwardedFields.stream().mapToInt(x -> x).toArray(),
calcProgram.getExprList().stream().anyMatch(
x -> PythonUtil.containsPythonCall(x, PythonFunctionKind.PANDAS)));
return new OneInputTransformation<>(
inputTransform,
name,
pythonOperator,
pythonOperatorResultTypeInfo,
inputTransform.getParallelism());
}
private Tuple2<int[], PythonFunctionInfo[]> extractPythonScalarFunctionInfos(
List<RexCall> rexCalls) {
LinkedHashMap<RexNode, Integer> inputNodes = new LinkedHashMap<>();
PythonFunctionInfo[] pythonFunctionInfos = rexCalls.stream()
.map(x -> CommonPythonUtil.createPythonFunctionInfo(x, inputNodes))
.collect(Collectors.toList())
.toArray(new PythonFunctionInfo[rexCalls.size()]);
int[] udfInputOffsets = inputNodes.keySet()
.stream()
.map(x -> {
if (x instanceof RexInputRef) {
return ((RexInputRef) x).getIndex();
} else if (x instanceof RexFieldAccess) {
return ((RexFieldAccess) x).getField().getIndex();
}
return null;
})
.mapToInt(i -> i)
.toArray();
return Tuple2.of(udfInputOffsets, pythonFunctionInfos);
}
@SuppressWarnings("unchecked")
private OneInputStreamOperator<RowData, RowData> getPythonScalarFunctionOperator(
Configuration config,
InternalTypeInfo<RowData> inputRowTypeInfo,
InternalTypeInfo<RowData> outputRowTypeInfo,
int[] udfInputOffsets,
PythonFunctionInfo[] pythonFunctionInfos,
int[] forwardedFields,
boolean isArrow) {
Class clazz;
if (isArrow) {
clazz = CommonPythonUtil.loadClass(ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
} else {
clazz = CommonPythonUtil.loadClass(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
}
try {
Constructor ctor = clazz.getConstructor(
Configuration.class,
PythonFunctionInfo[].class,
RowType.class,
RowType.class,
int[].class,
int[].class);
return (OneInputStreamOperator<RowData, RowData>) ctor.newInstance(
config,
pythonFunctionInfos,
inputRowTypeInfo.toRowType(),
outputRowTypeInfo.toRowType(),
udfInputOffsets,
forwardedFields);
} catch (Exception e) {
throw new TableException("Python Scalar Function Operator constructed failed.", e);
}
}
}
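A note on extractPythonScalarFunctionInfos above: the LinkedHashMap doubles as a deduplication table, assigning each distinct input RexNode the next offset exactly once, in first-seen order, so a field consumed by several Python calls is shipped to the Python worker a single time. The assignment logic, reduced to a self-contained sketch over hypothetical string operands instead of Calcite nodes:

    import java.util.Arrays;
    import java.util.LinkedHashMap;
    import java.util.Map;

    public class UdfInputOffsetsSketch {
        public static void main(String[] args) {
            // Hypothetical operands of two Python calls f(a, b) and g(a):
            // 'a' is referenced twice but must receive a single offset.
            String[] operands = {"a", "b", "a"};
            Map<String, Integer> inputNodes = new LinkedHashMap<>(); // keeps insertion order
            for (String operand : operands) {
                if (!inputNodes.containsKey(operand)) {
                    inputNodes.put(operand, inputNodes.size()); // next free offset
                }
            }
            int[] udfInputOffsets =
                    inputNodes.values().stream().mapToInt(Integer::intValue).toArray();
            System.out.println(Arrays.toString(udfInputOffsets)); // prints [0, 1]
        }
    }

The same first-seen numbering is what createPythonFunctionInfo records into inputNodes when it meets a plain field operand.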
@@ -16,55 +16,26 @@
  * limitations under the License.
  */
-package org.apache.flink.table.planner.plan.nodes.exec.stream
+package org.apache.flink.table.planner.plan.nodes.exec.stream;
 
-import org.apache.flink.api.dag.Transformation
-import org.apache.flink.core.memory.ManagedMemoryUseCase
-import org.apache.flink.table.data.RowData
-import org.apache.flink.table.planner.delegation.PlannerBase
-import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc
-import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNodeBase}
-import org.apache.flink.table.types.logical.RowType
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
+import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc;
+import org.apache.flink.table.types.logical.LogicalType;
 
-import org.apache.calcite.rex.RexProgram
-import java.util
+import org.apache.calcite.rex.RexProgram;
 
 /**
- * Stream ExecNode for Python ScalarFunctions.
- *
- * <p>Note: This class can't be ported to Java,
- * because java class can't extend scala interface with default implementation.
- * FLINK-20620 will port this class to Java.
+ * Stream {@link ExecNode} for Python ScalarFunctions.
  */
-class StreamExecPythonCalc(
-    calcProgram: RexProgram,
-    inputEdge: ExecEdge,
-    outputType: RowType,
-    description: String)
-  extends ExecNodeBase[RowData](
-    util.Collections.singletonList(inputEdge),
-    outputType,
-    description)
-  with StreamExecNode[RowData]
-  with CommonExecPythonCalc {
-
-  override protected def translateToPlanInternal(planner: PlannerBase): Transformation[RowData] = {
-    val inputTransform = getInputNodes.get(0).translateToPlan(planner)
-      .asInstanceOf[Transformation[RowData]]
-    val ret = createPythonOneInputTransformation(
-      inputTransform,
-      calcProgram,
-      "StreamExecPythonCalc",
-      getConfig(planner.getExecEnv, planner.getTableConfig))
-
-    if (inputsContainSingleton()) {
-      ret.setParallelism(1)
-      ret.setMaxParallelism(1)
-    }
-    if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
-      ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON)
-    }
-    ret
-  }
+public class StreamExecPythonCalc extends CommonExecPythonCalc implements StreamExecNode<RowData> {
+
+    public StreamExecPythonCalc(
+            RexProgram calcProgram,
+            ExecEdge inputEdge,
+            LogicalType outputType,
+            String description) {
+        super(calcProgram, inputEdge, outputType, description);
+    }
 }
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.plan.nodes.exec.utils;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.python.PythonFunction;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
import org.apache.flink.table.planner.functions.utils.ScalarSqlFunction;
import org.apache.flink.table.planner.functions.utils.TableSqlFunction;
import org.apache.flink.table.planner.utils.DummyStreamExecutionEnvironment;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.SqlTypeName;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* A utility class used in PyFlink.
*/
public class CommonPythonUtil {
private static final Method convertLiteralToPython;
private static final String PYTHON_DEPENDENCY_UTILS_CLASS = "org.apache.flink.python.util.PythonDependencyUtils";
static {
convertLiteralToPython = loadConvertLiteralToPythonMethod();
}
private CommonPythonUtil() {
}
public static Class loadClass(String className) {
try {
return Class.forName(className, false, Thread.currentThread().getContextClassLoader());
} catch (ClassNotFoundException e) {
throw new TableException(
"The dependency of 'flink-python' is not present on the classpath.", e);
}
}
@SuppressWarnings("unchecked")
public static Configuration getConfig(
StreamExecutionEnvironment env,
TableConfig tableConfig) {
Class clazz = loadClass(PYTHON_DEPENDENCY_UTILS_CLASS);
try {
StreamExecutionEnvironment realEnv = getRealEnvironment(env);
Method method = clazz.getDeclaredMethod(
"configurePythonDependencies", List.class, Configuration.class);
Configuration config = (Configuration) method.invoke(
null, realEnv.getCachedFiles(), getMergedConfiguration(realEnv, tableConfig));
config.setString("table.exec.timezone", tableConfig.getLocalTimeZone().getId());
return config;
} catch (NoSuchFieldException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
throw new TableException("Method configurePythonDependencies accessed failed.", e);
}
}
public static PythonFunctionInfo createPythonFunctionInfo(
RexCall pythonRexCall,
Map<RexNode, Integer> inputNodes) {
SqlOperator operator = pythonRexCall.getOperator();
try {
if (operator instanceof ScalarSqlFunction) {
return createPythonFunctionInfo(pythonRexCall, inputNodes, ((ScalarSqlFunction) operator).scalarFunction());
} else if (operator instanceof TableSqlFunction) {
return createPythonFunctionInfo(pythonRexCall, inputNodes, ((TableSqlFunction) operator).udtf());
} else if (operator instanceof BridgingSqlFunction) {
return createPythonFunctionInfo(pythonRexCall, inputNodes, ((BridgingSqlFunction) operator).getDefinition());
}
} catch (InvocationTargetException | IllegalAccessException e) {
throw new TableException("Method convertLiteralToPython accessed failed. ", e);
}
throw new TableException(String.format("Unsupported Python SqlFunction %s.", operator));
}
@SuppressWarnings("unchecked")
public static boolean isPythonWorkerUsingManagedMemory(Configuration config) {
Class clazz = loadClass("org.apache.flink.python.PythonOptions");
try {
return config.getBoolean((ConfigOption<Boolean>) (clazz.getField("USE_MANAGED_MEMORY").get(null)));
} catch (IllegalAccessException | NoSuchFieldException e) {
throw new TableException("Field USE_MANAGED_MEMORY accessed failed.", e);
}
}
@SuppressWarnings("unchecked")
private static Method loadConvertLiteralToPythonMethod() {
Class clazz = loadClass("org.apache.flink.api.common.python.PythonBridgeUtils");
try {
return clazz.getMethod("convertLiteralToPython", RexLiteral.class, SqlTypeName.class);
} catch (NoSuchMethodException e) {
throw new TableException(
"Failed to load method convertLiteralToPython.", e);
}
}
private static PythonFunctionInfo createPythonFunctionInfo(
RexCall pythonRexCall,
Map<RexNode, Integer> inputNodes,
FunctionDefinition functionDefinition) throws InvocationTargetException, IllegalAccessException {
ArrayList<Object> inputs = new ArrayList<>();
for (RexNode operand : pythonRexCall.getOperands()) {
if (operand instanceof RexCall) {
RexCall childPythonRexCall = (RexCall) operand;
PythonFunctionInfo argPythonInfo = createPythonFunctionInfo(childPythonRexCall, inputNodes);
inputs.add(argPythonInfo);
} else if (operand instanceof RexLiteral) {
RexLiteral literal = (RexLiteral) operand;
inputs.add(convertLiteralToPython.invoke(
null, literal, literal.getType().getSqlTypeName()));
} else {
if (inputNodes.containsKey(operand)) {
inputs.add(inputNodes.get(operand));
} else {
Integer inputOffset = inputNodes.size();
inputs.add(inputOffset);
inputNodes.put(operand, inputOffset);
}
}
}
return new PythonFunctionInfo((PythonFunction) functionDefinition, inputs.toArray());
}
private static StreamExecutionEnvironment getRealEnvironment(StreamExecutionEnvironment env)
throws NoSuchFieldException, IllegalAccessException {
Field realExecEnvField = DummyStreamExecutionEnvironment.class.getDeclaredField("realExecEnv");
realExecEnvField.setAccessible(true);
while (env instanceof DummyStreamExecutionEnvironment) {
env = (StreamExecutionEnvironment) realExecEnvField.get(env);
}
return env;
}
private static Configuration getMergedConfiguration(
StreamExecutionEnvironment env,
TableConfig tableConfig) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
Method method = StreamExecutionEnvironment.class.getDeclaredMethod("getConfiguration");
method.setAccessible(true);
Configuration config = new Configuration((Configuration) method.invoke(env));
config.addAll(tableConfig.getConfiguration());
return config;
}
}
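CommonPythonUtil reaches the optional flink-python module exclusively through reflection, so the planner compiles and runs without that dependency until a Python UDF actually shows up in a plan. The load-then-invoke pattern, isolated as a standalone sketch (the java.lang.System target below is a stand-in chosen so the snippet runs anywhere, not a Flink API):

    import java.lang.reflect.Method;

    public class ReflectiveBridgeSketch {
        // Mirrors the loadClass pattern above: initialize=false defers static
        // initializers until the class is first used.
        static Class<?> loadOptionalClass(String className) {
            try {
                return Class.forName(className, false, Thread.currentThread().getContextClassLoader());
            } catch (ClassNotFoundException e) {
                throw new IllegalStateException("Optional dependency is not on the classpath.", e);
            }
        }

        public static void main(String[] args) throws Exception {
            Class<?> clazz = loadOptionalClass("java.lang.System");
            Method method = clazz.getMethod("lineSeparator");
            System.out.print((String) method.invoke(null)); // null receiver: static method
        }
    }

getRealEnvironment and getMergedConfiguration go one step further and use setAccessible(true) to read private members, trading compile-time safety for independence from non-public Flink internals.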
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.plan.nodes.exec.common
import org.apache.calcite.rex._
import org.apache.flink.api.dag.Transformation
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.operators.OneInputStreamOperator
import org.apache.flink.streaming.api.transformations.OneInputTransformation
import org.apache.flink.table.data.RowData
import org.apache.flink.table.functions.python.{PythonFunctionInfo, PythonFunctionKind}
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.common.CommonPythonBase
import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc.ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME
import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc.PYTHON_SCALAR_FUNCTION_OPERATOR_NAME
import org.apache.flink.table.planner.plan.utils.PythonUtil.containsPythonCall
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
import org.apache.flink.table.types.logical.RowType
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
trait CommonExecPythonCalc extends CommonPythonBase {
private def extractPythonScalarFunctionInfos(
rexCalls: Array[RexCall]): (Array[Int], Array[PythonFunctionInfo]) = {
// using LinkedHashMap to keep the insert order
val inputNodes = new mutable.LinkedHashMap[RexNode, Integer]()
val pythonFunctionInfos = rexCalls.map(createPythonFunctionInfo(_, inputNodes))
val udfInputOffsets = inputNodes.toArray
.map(_._1)
.collect {
case inputRef: RexInputRef => inputRef.getIndex
case fac: RexFieldAccess => fac.getField.getIndex
}
(udfInputOffsets, pythonFunctionInfos)
}
private def getPythonScalarFunctionOperator(
config: Configuration,
inputRowTypeInfo: InternalTypeInfo[RowData],
outputRowTypeInfo: InternalTypeInfo[RowData],
udfInputOffsets: Array[Int],
pythonFunctionInfos: Array[PythonFunctionInfo],
forwardedFields: Array[Int],
isArrow: Boolean) = {
val clazz = if (isArrow) {
loadClass(ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME)
} else {
loadClass(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME)
}
val ctor = clazz.getConstructor(
classOf[Configuration],
classOf[Array[PythonFunctionInfo]],
classOf[RowType],
classOf[RowType],
classOf[Array[Int]],
classOf[Array[Int]])
ctor.newInstance(
config,
pythonFunctionInfos,
inputRowTypeInfo.toRowType,
outputRowTypeInfo.toRowType,
udfInputOffsets,
forwardedFields)
.asInstanceOf[OneInputStreamOperator[RowData, RowData]]
}
def createPythonOneInputTransformation(
inputTransform: Transformation[RowData],
calcProgram: RexProgram,
name: String,
config: Configuration): OneInputTransformation[RowData, RowData] = {
val pythonRexCalls = calcProgram.getProjectList
.map(calcProgram.expandLocalRef)
.collect { case call: RexCall => call }
.toArray
val forwardedFields: Array[Int] = calcProgram.getProjectList
.map(calcProgram.expandLocalRef)
.collect { case inputRef: RexInputRef => inputRef.getIndex }
.toArray
val (pythonUdfInputOffsets, pythonFunctionInfos) =
extractPythonScalarFunctionInfos(pythonRexCalls)
val inputLogicalTypes =
inputTransform.getOutputType.asInstanceOf[InternalTypeInfo[RowData]].toRowFieldTypes
val pythonOperatorInputTypeInfo = inputTransform.getOutputType
.asInstanceOf[InternalTypeInfo[RowData]]
val pythonOperatorResultTyeInfo = InternalTypeInfo.ofFields(
forwardedFields.map(inputLogicalTypes(_)) ++
pythonRexCalls.map(node => FlinkTypeFactory.toLogicalType(node.getType)): _*)
val pythonOperator = getPythonScalarFunctionOperator(
config,
pythonOperatorInputTypeInfo,
pythonOperatorResultTyeInfo,
pythonUdfInputOffsets,
pythonFunctionInfos,
forwardedFields,
calcProgram.getExprList.asScala.exists(containsPythonCall(_, PythonFunctionKind.PANDAS)))
new OneInputTransformation(
inputTransform,
name,
pythonOperator,
pythonOperatorResultTyeInfo,
inputTransform.getParallelism
)
}
}
object CommonExecPythonCalc {
val PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
"org.apache.flink.table.runtime.operators.python.scalar.RowDataPythonScalarFunctionOperator"
val ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
"org.apache.flink.table.runtime.operators.python.scalar.arrow." +
"RowDataArrowPythonScalarFunctionOperator"
}
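For reference, both the removed trait and its Java replacement compose the operator output row type the same way: forwarded input field types first, then one field per Python call result, which is also the column order the operator emits. A worked sketch of that concatenation over hypothetical type names:

    import java.util.ArrayList;
    import java.util.List;

    public class ResultTypeCompositionSketch {
        public static void main(String[] args) {
            // Hypothetical input row (INT, STRING, DOUBLE); the projection
            // forwards fields {0, 2} and appends two Python UDF results.
            List<String> inputFieldTypes = List.of("INT", "STRING", "DOUBLE");
            int[] forwardedFields = {0, 2};
            List<String> pythonCallTypes = List.of("BIGINT", "STRING");

            List<String> outputFieldTypes = new ArrayList<>();
            for (int i : forwardedFields) {
                outputFieldTypes.add(inputFieldTypes.get(i)); // forwarded fields first
            }
            outputFieldTypes.addAll(pythonCallTypes);          // then UDF results
            System.out.println(outputFieldTypes); // [INT, DOUBLE, BIGINT, STRING]
        }
    }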
@@ -43,8 +43,7 @@ class BatchPhysicalPythonCalc(
     traitSet,
     inputRel,
     calcProgram,
-    outputRowType)
-  with CommonExecPythonCalc {
+    outputRowType) {
 
   override def copy(traitSet: RelTraitSet, child: RelNode, program: RexProgram): Calc = {
     new BatchPhysicalPythonCalc(cluster, traitSet, child, program, outputRowType)