提交 febce359 编写于 作者: H huangxingbo 提交者: Dian Fu

[FLINK-20702][python] Support map operation chained together in Python Table API

This closes #14473.
上级 b474d287
......@@ -48,6 +48,8 @@ def wrap_inputs_as_row(*args):
import pandas as pd
if type(args[0]) == pd.Series:
return pd.concat(args, axis=1)
elif len(args) == 1 and isinstance(args[0], (pd.DataFrame, Row, Tuple)):
return args[0]
else:
return Row(*args)
......
......@@ -70,14 +70,24 @@ class RowBasedOperationTests(object):
res = pd.concat([x.a, x.c + x.d], axis=1)
return res
def func2(x):
return x * 2
pandas_udf = udf(func,
result_type=DataTypes.ROW(
[DataTypes.FIELD("c", DataTypes.BIGINT()),
DataTypes.FIELD("d", DataTypes.BIGINT())]),
func_type='pandas')
t.map(pandas_udf).execute_insert("Results").wait()
pandas_udf_2 = udf(func2,
result_type=DataTypes.ROW(
[DataTypes.FIELD("c", DataTypes.BIGINT()),
DataTypes.FIELD("d", DataTypes.BIGINT())]),
func_type='pandas')
t.map(pandas_udf).map(pandas_udf_2).execute_insert("Results").wait()
actual = source_sink_utils.results()
self.assert_equals(actual, ["2,4", "1,5", "1,14", "1,9", "2,7"])
self.assert_equals(actual, ["4,8", "2,10", "2,28", "2,18", "4,14"])
def test_flat_map(self):
t = self.t_env.from_elements(
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.plan.rules.logical;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
/**
* Rule will merge Python {@link FlinkLogicalCalc} used in Map operation, Flatten {@link FlinkLogicalCalc}
* and Python {@link FlinkLogicalCalc} used in Map operation together.
*/
public class PythonMapMergeRule extends RelOptRule {
public static final PythonMapMergeRule INSTANCE = new PythonMapMergeRule();
private PythonMapMergeRule() {
super(operand(FlinkLogicalCalc.class,
operand(FlinkLogicalCalc.class,
operand(FlinkLogicalCalc.class, none()))),
"PythonMapMergeRule");
}
@Override
public boolean matches(RelOptRuleCall call) {
FlinkLogicalCalc topCalc = call.rel(0);
FlinkLogicalCalc middleCalc = call.rel(1);
FlinkLogicalCalc bottomCalc = call.rel(2);
RexProgram topProgram = topCalc.getProgram();
List<RexNode> topProjects = topProgram.getProjectList()
.stream()
.map(topProgram::expandLocalRef)
.collect(Collectors.toList());
if (topProjects.size() != 1 || PythonUtil.isNonPythonCall(topProjects.get(0)) ||
!PythonUtil.takesRowAsInput((RexCall) topProjects.get(0))) {
return false;
}
RexProgram bottomProgram = bottomCalc.getProgram();
List<RexNode> bottomProjects = bottomProgram.getProjectList()
.stream()
.map(bottomProgram::expandLocalRef)
.collect(Collectors.toList());
if (bottomProjects.size() != 1 || PythonUtil.isNonPythonCall(bottomProjects.get(0))) {
return false;
}
RexProgram middleProgram = middleCalc.getProgram();
if (middleProgram.getCondition() != null) {
return false;
}
List<RexNode> middleProjects = middleProgram.getProjectList()
.stream()
.map(middleProgram::expandLocalRef)
.collect(Collectors.toList());
int inputRowFieldCount = middleProgram.getInputRowType()
.getFieldList()
.get(0)
.getValue()
.getFieldList().size();
return isFlattenCalc(middleProjects, inputRowFieldCount) &&
isTopCalcTakesWholeMiddleCalcAsInputs((RexCall) topProjects.get(0), middleProjects.size());
}
private boolean isTopCalcTakesWholeMiddleCalcAsInputs(RexCall pythonCall, int inputColumnCount) {
List<RexNode> pythonCallInputs = pythonCall.getOperands();
if (pythonCallInputs.size() != inputColumnCount) {
return false;
}
for (int i = 0; i < pythonCallInputs.size(); i++) {
RexNode input = pythonCallInputs.get(i);
if (input instanceof RexInputRef) {
if (((RexInputRef) input).getIndex() != i) {
return false;
}
} else {
return false;
}
}
return true;
}
private boolean isFlattenCalc(List<RexNode> middleProjects, int inputRowFieldCount) {
if (inputRowFieldCount != middleProjects.size()) {
return false;
}
for (int i = 0; i < inputRowFieldCount; i++) {
RexNode middleProject = middleProjects.get(i);
if (middleProject instanceof RexFieldAccess) {
RexFieldAccess rexField = ((RexFieldAccess) middleProject);
if (rexField.getField().getIndex() != i) {
return false;
}
RexNode expr = rexField.getReferenceExpr();
if (expr instanceof RexInputRef) {
if (((RexInputRef) expr).getIndex() != 0) {
return false;
}
} else {
return false;
}
} else {
return false;
}
}
return true;
}
@Override
public void onMatch(RelOptRuleCall call) {
FlinkLogicalCalc topCalc = call.rel(0);
FlinkLogicalCalc middleCalc = call.rel(1);
FlinkLogicalCalc bottomCalc = call.rel(2);
RexProgram topProgram = topCalc.getProgram();
List<RexCall> topProjects = topProgram.getProjectList()
.stream()
.map(topProgram::expandLocalRef)
.map(x -> (RexCall) x)
.collect(Collectors.toList());
RexCall topPythonCall = topProjects.get(0);
// merge topCalc and middleCalc
RexCall newPythonCall = topPythonCall.clone(topPythonCall.getType(),
Collections.singletonList(RexInputRef.of(0, bottomCalc.getRowType())));
List<RexCall> topMiddleMergedProjects = Collections.singletonList(newPythonCall);
FlinkLogicalCalc topMiddleMergedCalc = new FlinkLogicalCalc(
middleCalc.getCluster(),
middleCalc.getTraitSet(),
bottomCalc,
RexProgram.create(
bottomCalc.getRowType(),
topMiddleMergedProjects,
null,
Collections.singletonList("f0"),
call.builder().getRexBuilder()));
// merge bottomCalc
RexBuilder rexBuilder = call.builder().getRexBuilder();
RexProgram mergedProgram = RexProgramBuilder.mergePrograms(
topMiddleMergedCalc.getProgram(), bottomCalc.getProgram(), rexBuilder);
Calc newCalc = topMiddleMergedCalc.copy(
topMiddleMergedCalc.getTraitSet(), bottomCalc.getInput(), mergedProgram);
call.transformTo(newCalc);
}
}
......@@ -383,8 +383,9 @@ object FlinkBatchRuleSets {
PythonCalcSplitRule.SPLIT_PANDAS_IN_PROJECT,
PythonCalcSplitRule.EXPAND_PROJECT,
PythonCalcSplitRule.PUSH_CONDITION,
PythonCalcSplitRule.REWRITE_PROJECT
)
PythonCalcSplitRule.REWRITE_PROJECT,
PythonMapMergeRule.INSTANCE
)
/**
* RuleSet to do physical optimize for batch
......
......@@ -386,8 +386,9 @@ object FlinkStreamRuleSets {
PythonCalcSplitRule.SPLIT_PANDAS_IN_PROJECT,
PythonCalcSplitRule.EXPAND_PROJECT,
PythonCalcSplitRule.PUSH_CONDITION,
PythonCalcSplitRule.REWRITE_PROJECT
)
PythonCalcSplitRule.REWRITE_PROJECT,
PythonMapMergeRule.INSTANCE
)
/**
* RuleSet to do physical optimize for stream
......
......@@ -102,6 +102,14 @@ object PythonUtil {
}
}
def takesRowAsInput(call: RexCall): Boolean = {
(call.getOperator match {
case sfc: ScalarSqlFunction => sfc.scalarFunction
case tfc: TableSqlFunction => tfc.udtf
case bsf: BridgingSqlFunction => bsf.getDefinition
}).asInstanceOf[PythonFunction].takesRowAsInput()
}
private[this] def isPythonFunction(
function: FunctionDefinition,
pythonFunctionKind: PythonFunctionKind): Boolean = {
......
......@@ -257,11 +257,20 @@ public class JavaUserDefinedScalarFunctions {
return Row.of(a + 1, Row.of(a * a));
}
public Row eval(Object... args) {
return Row.of(1, Row.of(2));
}
@Override
public TypeInformation<?> getResultType(Class<?>[] signature) {
return Types.ROW(BasicTypeInfo.INT_TYPE_INFO, Types.ROW(BasicTypeInfo.INT_TYPE_INFO));
}
@Override
public boolean takesRowAsInput() {
return true;
}
@Override
public String toString() {
return name;
......
<?xml version="1.0" ?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to you under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<Root>
<TestCase name="testMapOperationsChained">
<Resource name="ast">
<![CDATA[
LogicalProject(_c0=[org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f1).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f1).f1).f0], _c1=[org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f1).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f1).f1).f1])
+- LogicalTableScan(table=[[default_catalog, default_database, source, source: [TestTableSource(a, b, c)]]])
]]>
</Resource>
<Resource name="optimized rel plan">
<![CDATA[
FlinkLogicalCalc(select=[f0.f0 AS _c0, f0.f1 AS _c1])
+- FlinkLogicalCalc(select=[pyFunc2(pyFunc2(pyFunc2(a, b, c))) AS f0])
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, source, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
</TestCase>
</Root>
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.plan.rules.logical
import org.apache.flink.api.scala._
import org.apache.flink.table.api._
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.optimize.program._
import org.apache.flink.table.planner.plan.rules.{FlinkBatchRuleSets, FlinkStreamRuleSets}
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.RowPythonScalarFunction
import org.apache.flink.table.planner.utils.TableTestBase
import org.apache.calcite.plan.hep.HepMatchOrder
import org.junit.{Before, Test}
/**
* Test for [[PythonMapMergeRule]].
*/
class PythonMapMergeRuleTest extends TableTestBase {
private val util = batchTestUtil()
@Before
def setup(): Unit = {
val programs = new FlinkChainedProgram[BatchOptimizeContext]()
programs.addLast(
"logical",
FlinkVolcanoProgramBuilder.newBuilder
.add(FlinkBatchRuleSets.LOGICAL_OPT_RULES)
.setRequiredOutputTraits(Array(FlinkConventions.LOGICAL))
.build())
programs.addLast(
"logical_rewrite",
FlinkHepRuleSetProgramBuilder.newBuilder
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
.setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
.add(FlinkStreamRuleSets.LOGICAL_REWRITE)
.build())
util.replaceBatchProgram(programs)
}
@Test
def testMapOperationsChained(): Unit = {
val sourceTable = util.addTableSource[(Int, Int, Int)]("source", 'a, 'b, 'c)
val func = new RowPythonScalarFunction("pyFunc2")
val result = sourceTable.map(func(withColumns('*)))
.map(func(withColumns('*)))
.map(func(withColumns('*)))
util.verifyRelPlan(result)
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册