Commit 7d6715c2 authored by huangxingbo, committed by Dian Fu


[FLINK-20756][python] Fix the bug that it doesn't support field access of expression containing Python UDF in the condition of Calc

This closes #14492.
Parent 8da64d79
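
This fix targets queries that access a field of an expression containing a Python UDF inside the WHERE clause, i.e. in the condition of a Calc node, as exercised by the planner test added below. A minimal sketch of such a query, assuming a configured `TableEnvironment` named `tEnv` with `pyFunc5` (a Python UDF returning a row) and `RowJavaFunc` (a Java UDF returning a row) already registered as in the test changes:

```scala
// Sketch only: tEnv, pyFunc5 and RowJavaFunc are assumed to be registered as in
// the PythonCalcSplitRuleTest changes below; this snippet is not part of the commit.
val result = tEnv.sqlQuery(
  """
    |SELECT a + 1
    |FROM MyTable
    |WHERE RowJavaFunc(pyFunc5(a).f0).f0 IS NULL AND b > 0
  """.stripMargin)
```

Previously the split rules only handled such field accesses in the projection (`SPLIT_REX_FIELD`); this commit replaces that rule with `SPLIT_PROJECTION_REX_FIELD` and adds `SPLIT_CONDITION_REX_FIELD` to cover the condition as well.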
@@ -377,7 +377,8 @@ object FlinkBatchRuleSets {
// merge calc after calc transpose
FlinkCalcMergeRule.INSTANCE,
// Rule that splits python ScalarFunctions from java/scala ScalarFunctions
PythonCalcSplitRule.SPLIT_REX_FIELD,
PythonCalcSplitRule.SPLIT_CONDITION_REX_FIELD,
PythonCalcSplitRule.SPLIT_PROJECTION_REX_FIELD,
PythonCalcSplitRule.SPLIT_CONDITION,
PythonCalcSplitRule.SPLIT_PROJECT,
PythonCalcSplitRule.SPLIT_PANDAS_IN_PROJECT,
......
@@ -381,7 +381,8 @@ object FlinkStreamRuleSets {
//Rule that rewrites temporal join with extracted primary key
TemporalJoinRewriteWithUniqueKeyRule.INSTANCE,
// Rule that splits python ScalarFunctions from java/scala ScalarFunctions.
PythonCalcSplitRule.SPLIT_REX_FIELD,
PythonCalcSplitRule.SPLIT_CONDITION_REX_FIELD,
PythonCalcSplitRule.SPLIT_PROJECTION_REX_FIELD,
PythonCalcSplitRule.SPLIT_CONDITION,
PythonCalcSplitRule.SPLIT_PROJECT,
PythonCalcSplitRule.SPLIT_PANDAS_IN_PROJECT,
......
@@ -168,12 +168,34 @@ abstract class PythonCalcSplitProjectionRuleBase(description: String)
}
}
abstract class PythonCalcSplitRexFieldRuleBase(description: String)
extends PythonCalcSplitRuleBase(description) {
override def needConvert(program: RexProgram, node: RexNode): Boolean = {
node match {
case x: RexFieldAccess => x.getReferenceExpr match {
case y: RexLocalRef if containsPythonCall(program.expandLocalRef(y)) => true
case _ => false
}
case _ => false
}
}
protected def containsFieldAccessAfterPythonCall(node: RexNode): Boolean = {
node match {
case call: RexCall => call.getOperands.exists(containsFieldAccessAfterPythonCall)
case x: RexFieldAccess => containsPythonCall(x.getReferenceExpr)
case _ => false
}
}
}
/**
* Rule that splits the RexField with the input of Python function contained in the projection of
* [[FlinkLogicalCalc]]s.
*/
object PythonCalcSplitRexFieldRule extends PythonCalcSplitRuleBase(
"PythonCalcSplitRexFieldRule") {
object PythonCalcSplitProjectionRexFieldRule extends PythonCalcSplitRexFieldRuleBase(
"PythonCalcSplitProjectionRexFieldRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val calc: FlinkLogicalCalc = call.rel(0).asInstanceOf[FlinkLogicalCalc]
@@ -182,28 +204,31 @@ object PythonCalcSplitRexFieldRule extends PythonCalcSplitRuleBase(
projects.exists(containsFieldAccessAfterPythonCall)
}
override def needConvert(program: RexProgram, node: RexNode): Boolean = {
node match {
case x: RexFieldAccess => x.getReferenceExpr match {
case y: RexLocalRef if isPythonCall(program.expandLocalRef(y)) => true
case _ => false
}
case _ => false
}
}
override def split(program: RexProgram, splitter: ScalarFunctionSplitter)
: (Option[RexNode], Option[RexNode], Seq[RexNode]) = {
(Option(program.getCondition).map(program.expandLocalRef), None,
program.getProjectList.map(_.accept(splitter)))
}
}
private def containsFieldAccessAfterPythonCall(node: RexNode): Boolean = {
node match {
case call: RexCall => call.getOperands.exists(containsFieldAccessAfterPythonCall)
case x: RexFieldAccess => isPythonCall(x.getReferenceExpr)
case _ => false
}
/**
* Rule that splits the RexField with the input of Python function contained in the condition of
* [[FlinkLogicalCalc]]s.
*/
object PythonCalcSplitConditionRexFieldRule extends PythonCalcSplitRexFieldRuleBase(
"PythonCalcSplitConditionRexFieldRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val calc: FlinkLogicalCalc = call.rel(0).asInstanceOf[FlinkLogicalCalc]
Option(calc.getProgram.getCondition)
.map(calc.getProgram.expandLocalRef).exists(containsFieldAccessAfterPythonCall)
}
override def split(program: RexProgram, splitter: ScalarFunctionSplitter)
: (Option[RexNode], Option[RexNode], Seq[RexNode]) = {
(None, Option(program.getCondition).map(_.accept(splitter)),
program.getProjectList.map(_.accept(splitter)))
}
}
@@ -366,7 +391,7 @@ private class ScalarFunctionSplitter(
if (needConvert(fieldAccess)) {
val expr = fieldAccess.getReferenceExpr
expr match {
case localRef: RexLocalRef if isPythonCall(program.expandLocalRef(localRef))
case localRef: RexLocalRef if containsPythonCall(program.expandLocalRef(localRef))
=> getExtractedRexFieldAccess(fieldAccess, localRef.getIndex)
case _ => getExtractedRexNode(fieldAccess)
}
@@ -455,7 +480,8 @@ object PythonCalcSplitRule {
val SPLIT_CONDITION: RelOptRule = PythonCalcSplitConditionRule
val SPLIT_PROJECT: RelOptRule = PythonCalcSplitProjectionRule
val SPLIT_PANDAS_IN_PROJECT: RelOptRule = PythonCalcSplitPandasInProjectionRule
val SPLIT_REX_FIELD: RelOptRule = PythonCalcSplitRexFieldRule
val SPLIT_PROJECTION_REX_FIELD: RelOptRule = PythonCalcSplitProjectionRexFieldRule
val SPLIT_CONDITION_REX_FIELD: RelOptRule = PythonCalcSplitConditionRexFieldRule
val EXPAND_PROJECT: RelOptRule = PythonCalcExpandProjectRule
val PUSH_CONDITION: RelOptRule = PythonCalcPushConditionRule
val REWRITE_PROJECT: RelOptRule = PythonCalcRewriteProjectionRule
......
@@ -19,7 +19,7 @@
package org.apache.flink.table.planner.plan.utils
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rex.{RexCall, RexNode}
import org.apache.calcite.rex.{RexCall, RexFieldAccess, RexNode}
import org.apache.flink.table.functions.FunctionDefinition
import org.apache.flink.table.functions.python.{PythonFunction, PythonFunctionKind}
import org.apache.flink.table.planner.functions.aggfunctions.{DeclarativeAggregateFunction, InternalAggregateFunction}
@@ -159,6 +159,10 @@ object PythonUtil {
(recursive && call.getOperands.exists(_.accept(this)))
}
override def visitFieldAccess(fieldAccess: RexFieldAccess): Boolean = {
fieldAccess.getReferenceExpr.accept(this)
}
override def visitNode(rexNode: RexNode): Boolean = false
}
}
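
The `visitFieldAccess` case added to `PythonUtil` above is what lets `containsPythonCall` look through a field access such as `pyFunc5(a).f0`, whose top-level node is a `RexFieldAccess` rather than a `RexCall`. The following self-contained sketch models that recursion with a simplified expression type; the classes are illustrative stand-ins, not the Calcite `RexNode` hierarchy:

```scala
// Simplified model for illustration only -- these are NOT the Calcite/Flink types.
object FieldAccessSketch extends App {
  sealed trait Expr
  case class InputRef(index: Int) extends Expr
  case class Call(name: String, isPython: Boolean, operands: Seq[Expr]) extends Expr
  case class FieldAccess(ref: Expr, field: String) extends Expr

  // Mirrors the idea of containsPythonCall: a Python call anywhere in the tree
  // makes the whole expression count as containing a Python call.
  def containsPythonCall(e: Expr): Boolean = e match {
    case Call(_, isPython, operands) => isPython || operands.exists(containsPythonCall)
    // The case this commit adds: descend into the reference expression.
    case FieldAccess(ref, _)         => containsPythonCall(ref)
    case InputRef(_)                 => false
  }

  // pyFunc5(a).f0 -- detected only because field access is now visited.
  val expr = FieldAccess(Call("pyFunc5", isPython = true, Seq(InputRef(0))), "f0")
  println(containsPythonCall(expr)) // prints: true
}
```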
@@ -287,6 +287,32 @@ public class JavaUserDefinedScalarFunctions {
}
}
/**
* Test for Python Scalar Function.
*/
public static class RowJavaScalarFunction extends ScalarFunction {
private final String name;
public RowJavaScalarFunction(String name) {
this.name = name;
}
public Row eval(Object... a) {
return Row.of(1, 2);
}
@Override
public TypeInformation<?> getResultType(Class<?>[] signature) {
return Types.ROW(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO);
}
@Override
public String toString() {
return name;
}
}
/**
* Test for Pandas Python Scalar Function.
*/
......
@@ -427,6 +427,26 @@ FlinkLogicalCalc(select=[f0.f0 AS f0, f0.f1 AS f1])
]]>
</Resource>
</TestCase>
<TestCase name="testPythonFunctionWithCompositeWhereClause">
<Resource name="sql">
<![CDATA[SELECT a + 1 FROM MyTable where RowJavaFunc(pyFunc5(a).f0).f0 is NULL and b > 0]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(EXPR$0=[+($0, 1)])
+- LogicalFilter(condition=[AND(IS NULL(RowJavaFunc(pyFunc5($0).f0).f0), >($1, 0))])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
]]>
</Resource>
<Resource name="optimized rel plan">
<![CDATA[
FlinkLogicalCalc(select=[+(a, 1) AS EXPR$0], where=[AND(IS NULL(f0.f0), >(b, 0))])
+- FlinkLogicalCalc(select=[a, b, RowJavaFunc(f0.f0) AS f0])
+- FlinkLogicalCalc(select=[a, b, pyFunc5(a) AS f0])
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
<TestCase name="testPythonFunctionWithCompositeInputsAndWhereClause">
<Resource name="sql">
<![CDATA[SELECT a, pyFunc1(b, d._1) FROM MyTable WHERE a + 1 > 0]]>
......
@@ -23,7 +23,7 @@ import org.apache.flink.table.api._
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.optimize.program._
import org.apache.flink.table.planner.plan.rules.{FlinkBatchRuleSets, FlinkStreamRuleSets}
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.{BooleanPandasScalarFunction, BooleanPythonScalarFunction, PandasScalarFunction, PythonScalarFunction, RowPythonScalarFunction}
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.{BooleanPandasScalarFunction, BooleanPythonScalarFunction, PandasScalarFunction, PythonScalarFunction, RowJavaScalarFunction, RowPythonScalarFunction}
import org.apache.flink.table.planner.utils.TableTestBase
import org.apache.calcite.plan.hep.HepMatchOrder
@@ -60,6 +60,7 @@ class PythonCalcSplitRuleTest extends TableTestBase {
util.addFunction("pyFunc3", new PythonScalarFunction("pyFunc3"))
util.addFunction("pyFunc4", new BooleanPythonScalarFunction("pyFunc4"))
util.addFunction("pyFunc5", new RowPythonScalarFunction("pyFunc5"))
util.addFunction("RowJavaFunc", new RowJavaScalarFunction("RowJavaFunc"))
util.addFunction("pandasFunc1", new PandasScalarFunction("pandasFunc1"))
util.addFunction("pandasFunc2", new PandasScalarFunction("pandasFunc2"))
util.addFunction("pandasFunc3", new PandasScalarFunction("pandasFunc3"))
@@ -223,4 +224,10 @@ class PythonCalcSplitRuleTest extends TableTestBase {
val sqlQuery = "SELECT e.* FROM (SELECT pyFunc5(d._1) as e FROM MyTable) AS T"
util.verifyRelPlan(sqlQuery)
}
@Test
def testPythonFunctionWithCompositeWhereClause(): Unit = {
val sqlQuery = "SELECT a + 1 FROM MyTable where RowJavaFunc(pyFunc5(a).f0).f0 is NULL and b > 0"
util.verifyRelPlan(sqlQuery)
}
}