未验证 提交 110a0444 编写于 作者: J Jingsong Lee 提交者: GitHub

[FLINK-20745][table] Clean useless codes: Never push calcProgram to correlate

This closes #14474
上级 1aa9074e
...@@ -27,7 +27,6 @@ import org.apache.flink.table.types.logical.RowType; ...@@ -27,7 +27,6 @@ import org.apache.flink.table.types.logical.RowType;
import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import javax.annotation.Nullable; import javax.annotation.Nullable;
...@@ -38,7 +37,6 @@ public class BatchExecCorrelate extends CommonExecCorrelate implements BatchExec ...@@ -38,7 +37,6 @@ public class BatchExecCorrelate extends CommonExecCorrelate implements BatchExec
public BatchExecCorrelate( public BatchExecCorrelate(
FlinkJoinType joinType, FlinkJoinType joinType,
@Nullable RexProgram project,
RexCall invocation, RexCall invocation,
@Nullable RexNode condition, @Nullable RexNode condition,
ExecEdge inputEdge, ExecEdge inputEdge,
...@@ -46,7 +44,6 @@ public class BatchExecCorrelate extends CommonExecCorrelate implements BatchExec ...@@ -46,7 +44,6 @@ public class BatchExecCorrelate extends CommonExecCorrelate implements BatchExec
String description) { String description) {
super( super(
joinType, joinType,
project,
invocation, invocation,
condition, condition,
TableStreamOperator.class, TableStreamOperator.class,
......
...@@ -32,7 +32,6 @@ import org.apache.flink.table.types.logical.RowType; ...@@ -32,7 +32,6 @@ import org.apache.flink.table.types.logical.RowType;
import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import javax.annotation.Nullable; import javax.annotation.Nullable;
...@@ -44,8 +43,6 @@ import java.util.Optional; ...@@ -44,8 +43,6 @@ import java.util.Optional;
*/ */
public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> { public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> {
private final FlinkJoinType joinType; private final FlinkJoinType joinType;
@Nullable
private final RexProgram project;
private final RexCall invocation; private final RexCall invocation;
@Nullable @Nullable
private final RexNode condition; private final RexNode condition;
...@@ -54,7 +51,6 @@ public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> { ...@@ -54,7 +51,6 @@ public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> {
public CommonExecCorrelate( public CommonExecCorrelate(
FlinkJoinType joinType, FlinkJoinType joinType,
@Nullable RexProgram project,
RexCall invocation, RexCall invocation,
@Nullable RexNode condition, @Nullable RexNode condition,
Class<?> operatorBaseClass, Class<?> operatorBaseClass,
...@@ -64,7 +60,6 @@ public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> { ...@@ -64,7 +60,6 @@ public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> {
String description) { String description) {
super(Collections.singletonList(inputEdge), outputType, description); super(Collections.singletonList(inputEdge), outputType, description);
this.joinType = joinType; this.joinType = joinType;
this.project = project;
this.invocation = invocation; this.invocation = invocation;
this.condition = condition; this.condition = condition;
this.operatorBaseClass = operatorBaseClass; this.operatorBaseClass = operatorBaseClass;
...@@ -83,7 +78,6 @@ public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> { ...@@ -83,7 +78,6 @@ public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> {
ctx, ctx,
inputTransform, inputTransform,
(RowType) inputNode.getOutputType(), (RowType) inputNode.getOutputType(),
JavaScalaConversionUtil.toScala(Optional.ofNullable(project)),
invocation, invocation,
JavaScalaConversionUtil.toScala(Optional.ofNullable(condition)), JavaScalaConversionUtil.toScala(Optional.ofNullable(condition)),
(RowType) getOutputType(), (RowType) getOutputType(),
......
...@@ -28,7 +28,6 @@ import org.apache.flink.table.types.logical.RowType; ...@@ -28,7 +28,6 @@ import org.apache.flink.table.types.logical.RowType;
import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import javax.annotation.Nullable; import javax.annotation.Nullable;
...@@ -39,7 +38,6 @@ public class StreamExecCorrelate extends CommonExecCorrelate implements StreamEx ...@@ -39,7 +38,6 @@ public class StreamExecCorrelate extends CommonExecCorrelate implements StreamEx
public StreamExecCorrelate( public StreamExecCorrelate(
FlinkJoinType joinType, FlinkJoinType joinType,
@Nullable RexProgram project,
RexCall invocation, RexCall invocation,
@Nullable RexNode condition, @Nullable RexNode condition,
ExecEdge inputEdge, ExecEdge inputEdge,
...@@ -47,7 +45,6 @@ public class StreamExecCorrelate extends CommonExecCorrelate implements StreamEx ...@@ -47,7 +45,6 @@ public class StreamExecCorrelate extends CommonExecCorrelate implements StreamEx
String description) { String description) {
super( super(
joinType, joinType,
project,
invocation, invocation,
condition, condition,
AbstractProcessStreamOperator.class, AbstractProcessStreamOperator.class,
......
...@@ -118,7 +118,6 @@ public class BatchPhysicalPythonCorrelateRule extends ConverterRule { ...@@ -118,7 +118,6 @@ public class BatchPhysicalPythonCorrelateRule extends ConverterRule {
convInput, convInput,
scan, scan,
condition, condition,
null,
correlate.getRowType(), correlate.getRowType(),
correlate.getJoinType()); correlate.getJoinType());
} }
......
...@@ -124,7 +124,6 @@ public class StreamPhysicalPythonCorrelateRule extends ConverterRule { ...@@ -124,7 +124,6 @@ public class StreamPhysicalPythonCorrelateRule extends ConverterRule {
correlate.getCluster(), correlate.getCluster(),
traitSet, traitSet,
convInput, convInput,
null,
scan, scan,
condition, condition,
correlate.getRowType(), correlate.getRowType(),
......
...@@ -21,8 +21,8 @@ package org.apache.flink.table.planner.codegen ...@@ -21,8 +21,8 @@ package org.apache.flink.table.planner.codegen
import org.apache.flink.api.common.functions.Function import org.apache.flink.api.common.functions.Function
import org.apache.flink.api.dag.Transformation import org.apache.flink.api.dag.Transformation
import org.apache.flink.table.api.{TableConfig, TableException, ValidationException} import org.apache.flink.table.api.{TableConfig, TableException, ValidationException}
import org.apache.flink.table.data.RowData
import org.apache.flink.table.data.utils.JoinedRowData import org.apache.flink.table.data.utils.JoinedRowData
import org.apache.flink.table.data.{GenericRowData, RowData}
import org.apache.flink.table.functions.FunctionKind import org.apache.flink.table.functions.FunctionKind
import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.CodeGenUtils._
...@@ -37,8 +37,6 @@ import org.apache.flink.table.types.logical.RowType ...@@ -37,8 +37,6 @@ import org.apache.flink.table.types.logical.RowType
import org.apache.calcite.rex._ import org.apache.calcite.rex._
import scala.collection.JavaConversions._
object CorrelateCodeGenerator { object CorrelateCodeGenerator {
def generateCorrelateTransformation( def generateCorrelateTransformation(
...@@ -46,7 +44,6 @@ object CorrelateCodeGenerator { ...@@ -46,7 +44,6 @@ object CorrelateCodeGenerator {
operatorCtx: CodeGeneratorContext, operatorCtx: CodeGeneratorContext,
inputTransformation: Transformation[RowData], inputTransformation: Transformation[RowData],
inputType: RowType, inputType: RowType,
projectProgram: Option[RexProgram],
invocation: RexCall, invocation: RexCall,
condition: Option[RexNode], condition: Option[RexNode],
outputType: RowType, outputType: RowType,
...@@ -68,19 +65,6 @@ object CorrelateCodeGenerator { ...@@ -68,19 +65,6 @@ object CorrelateCodeGenerator {
s"Currently, only table functions can be used in a correlate operation.") s"Currently, only table functions can be used in a correlate operation.")
} }
val swallowInputOnly = if (projectProgram.isDefined) {
val program = projectProgram.get
val selects = program.getProjectList.map(_.getIndex)
val inputFieldCnt = program.getInputRowType.getFieldCount
val swallowInputOnly = selects.head > inputFieldCnt &&
(inputFieldCnt - outputType.getFieldCount == inputType.getFieldCount)
// partial output or output right only
swallowInputOnly
} else {
// completely output left input + right
false
}
// adjust indicies of InputRefs to adhere to schema expected by generator // adjust indicies of InputRefs to adhere to schema expected by generator
val changeInputRefIndexShuttle = new RexShuttle { val changeInputRefIndexShuttle = new RexShuttle {
override def visitInputRef(inputRef: RexInputRef): RexNode = { override def visitInputRef(inputRef: RexInputRef): RexNode = {
...@@ -92,8 +76,6 @@ object CorrelateCodeGenerator { ...@@ -92,8 +76,6 @@ object CorrelateCodeGenerator {
operatorCtx, operatorCtx,
config, config,
inputType, inputType,
projectProgram,
swallowInputOnly,
condition.map(_.accept(changeInputRefIndexShuttle)), condition.map(_.accept(changeInputRefIndexShuttle)),
outputType, outputType,
joinType, joinType,
...@@ -117,8 +99,6 @@ object CorrelateCodeGenerator { ...@@ -117,8 +99,6 @@ object CorrelateCodeGenerator {
ctx: CodeGeneratorContext, ctx: CodeGeneratorContext,
config: TableConfig, config: TableConfig,
inputType: RowType, inputType: RowType,
projectProgram: Option[RexProgram],
swallowInputOnly: Boolean = false,
condition: Option[RexNode], condition: Option[RexNode],
returnType: RowType, returnType: RowType,
joinType: FlinkJoinType, joinType: FlinkJoinType,
...@@ -136,8 +116,6 @@ object CorrelateCodeGenerator { ...@@ -136,8 +116,6 @@ object CorrelateCodeGenerator {
ctx, ctx,
config, config,
inputType, inputType,
projectProgram,
swallowInputOnly,
functionResultType, functionResultType,
returnType, returnType,
condition, condition,
...@@ -167,78 +145,27 @@ object CorrelateCodeGenerator { ...@@ -167,78 +145,27 @@ object CorrelateCodeGenerator {
// 3. left join // 3. left join
if (joinType == FlinkJoinType.LEFT) { if (joinType == FlinkJoinType.LEFT) {
if (swallowInputOnly) { // output all fields of left and right
// and the returned row table function is empty, collect a null // in case of left outer join and the returned row of table function is empty,
val nullRowTerm = CodeGenUtils.newName("nullRow") // fill all fields of row with null
ctx.addReusableOutputRecord(functionResultType, classOf[GenericRowData], nullRowTerm) val joinedRowTerm = CodeGenUtils.newName("joinedRow")
ctx.addReusableNullRow(nullRowTerm, functionResultType.getFieldCount) val nullRowTerm = CodeGenUtils.newName("nullRow")
val header = if (retainHeader) { ctx.addReusableOutputRecord(returnType, classOf[JoinedRowData], joinedRowTerm)
s"$nullRowTerm.setRowKind(${exprGenerator.input1Term}.getRowKind());" ctx.addReusableNullRow(nullRowTerm, functionResultType.getFieldCount)
} else { val header = if (retainHeader) {
"" s"$joinedRowTerm.setRowKind(${exprGenerator.input1Term}.getRowKind());"
}
body +=
s"""
|boolean hasOutput = $correlateCollectorTerm.isCollected();
|if (!hasOutput) {
| $header
| $correlateCollectorTerm.outputResult($nullRowTerm);
|}
|""".stripMargin
} else if (projectProgram.isDefined) {
// output partial fields of left and right
val outputTerm = CodeGenUtils.newName("projectOut")
ctx.addReusableOutputRecord(returnType, classOf[GenericRowData], outputTerm)
val header = if (retainHeader) {
s"$outputTerm.setRowKind(${CodeGenUtils.DEFAULT_INPUT1_TERM}.getRowKind());"
} else {
""
}
val projectionExpression = generateProjectResultExpr(
ctx,
config,
inputType,
functionResultType,
udtfAlwaysNull = true,
returnType,
outputTerm,
projectProgram.get)
body +=
s"""
|boolean hasOutput = $correlateCollectorTerm.isCollected();
|if (!hasOutput) {
| ${projectionExpression.code}
| $header
| $correlateCollectorTerm.outputResult($outputTerm);
|}
|""".stripMargin
} else { } else {
// output all fields of left and right ""
// in case of left outer join and the returned row of table function is empty, }
// fill all fields of row with null body +=
val joinedRowTerm = CodeGenUtils.newName("joinedRow") s"""
val nullRowTerm = CodeGenUtils.newName("nullRow") |boolean hasOutput = $correlateCollectorTerm.isCollected();
ctx.addReusableOutputRecord(returnType, classOf[JoinedRowData], joinedRowTerm) |if (!hasOutput) {
ctx.addReusableNullRow(nullRowTerm, functionResultType.getFieldCount) | $joinedRowTerm.replace(${exprGenerator.input1Term}, $nullRowTerm);
val header = if (retainHeader) { | $header
s"$joinedRowTerm.setRowKind(${exprGenerator.input1Term}.getRowKind());" | $correlateCollectorTerm.outputResult($joinedRowTerm);
} else { |}
"" |""".stripMargin
}
body +=
s"""
|boolean hasOutput = $correlateCollectorTerm.isCollected();
|if (!hasOutput) {
| $joinedRowTerm.replace(${exprGenerator.input1Term}, $nullRowTerm);
| $header
| $correlateCollectorTerm.outputResult($joinedRowTerm);
|}
|""".stripMargin
}
} else if (joinType != FlinkJoinType.INNER) { } else if (joinType != FlinkJoinType.INNER) {
throw new TableException(s"Unsupported JoinRelType: $joinType for correlate join.") throw new TableException(s"Unsupported JoinRelType: $joinType for correlate join.")
} }
...@@ -248,34 +175,6 @@ object CorrelateCodeGenerator { ...@@ -248,34 +175,6 @@ object CorrelateCodeGenerator {
new CodeGenOperatorFactory(genOperator) new CodeGenOperatorFactory(genOperator)
} }
private def generateProjectResultExpr(
ctx: CodeGeneratorContext,
config: TableConfig,
input1Type: RowType,
functionResultType: RowType,
udtfAlwaysNull: Boolean,
returnType: RowType,
outputTerm: String,
program: RexProgram): GeneratedExpression = {
val projectExprGenerator = new ExprCodeGenerator(ctx, udtfAlwaysNull)
.bindInput(input1Type, CodeGenUtils.DEFAULT_INPUT1_TERM)
if (udtfAlwaysNull) {
val udtfNullRow = CodeGenUtils.newName("udtfNullRow")
ctx.addReusableNullRow(udtfNullRow, functionResultType.getFieldCount)
projectExprGenerator.bindSecondInput(
functionResultType,
udtfNullRow)
} else {
projectExprGenerator.bindSecondInput(
functionResultType)
}
val projection = program.getProjectList.map(program.expandLocalRef)
val projectionExprs = projection.map(projectExprGenerator.generateExpression)
projectExprGenerator.generateResultExpression(
projectionExprs, returnType, classOf[GenericRowData], outputTerm)
}
/** /**
* Generates a collector that correlates input and converted table function results. Returns a * Generates a collector that correlates input and converted table function results. Returns a
* collector term for referencing the collector. * collector term for referencing the collector.
...@@ -284,8 +183,6 @@ object CorrelateCodeGenerator { ...@@ -284,8 +183,6 @@ object CorrelateCodeGenerator {
ctx: CodeGeneratorContext, ctx: CodeGeneratorContext,
config: TableConfig, config: TableConfig,
inputType: RowType, inputType: RowType,
projectProgram: Option[RexProgram],
swallowInputOnly: Boolean,
functionResultType: RowType, functionResultType: RowType,
resultType: RowType, resultType: RowType,
condition: Option[RexNode], condition: Option[RexNode],
...@@ -298,45 +195,7 @@ object CorrelateCodeGenerator { ...@@ -298,45 +195,7 @@ object CorrelateCodeGenerator {
val collectorCtx = CodeGeneratorContext(config) val collectorCtx = CodeGeneratorContext(config)
val body = if (projectProgram.isDefined) { val body = {
// partial output
if (swallowInputOnly) {
// output right only
val header = if (retainHeader) {
s"$udtfInputTerm.setRowKind($inputTerm.getRowKind());"
} else {
""
}
s"""
|$header
|outputResult($udtfInputTerm);
""".stripMargin
} else {
val outputTerm = CodeGenUtils.newName("projectOut")
collectorCtx.addReusableOutputRecord(resultType, classOf[GenericRowData], outputTerm)
val header = if (retainHeader) {
s"$outputTerm.setRowKind($inputTerm.getRowKind());"
} else {
""
}
val projectionExpression = generateProjectResultExpr(
collectorCtx,
config,
inputType,
functionResultType,
udtfAlwaysNull = false,
resultType,
outputTerm,
projectProgram.get)
s"""
|$header
|${projectionExpression.code}
|outputResult(${projectionExpression.resultTerm});
""".stripMargin
}
} else {
// completely output left input + right // completely output left input + right
val joinedRowTerm = CodeGenUtils.newName("joinedRow") val joinedRowTerm = CodeGenUtils.newName("joinedRow")
collectorCtx.addReusableOutputRecord(resultType, classOf[JoinedRowData], joinedRowTerm) collectorCtx.addReusableOutputRecord(resultType, classOf[JoinedRowData], joinedRowTerm)
......
...@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} ...@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{Correlate, JoinRelType} import org.apache.calcite.rel.core.{Correlate, JoinRelType}
import org.apache.calcite.rex.{RexCall, RexNode, RexProgram} import org.apache.calcite.rex.{RexCall, RexNode}
/** /**
* Batch physical RelNode for [[Correlate]] (Java/Scala user defined table function). * Batch physical RelNode for [[Correlate]] (Java/Scala user defined table function).
...@@ -38,7 +38,6 @@ class BatchPhysicalCorrelate( ...@@ -38,7 +38,6 @@ class BatchPhysicalCorrelate(
inputRel: RelNode, inputRel: RelNode,
scan: FlinkLogicalTableFunctionScan, scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode], condition: Option[RexNode],
projectProgram: Option[RexProgram],
outputRowType: RelDataType, outputRowType: RelDataType,
joinType: JoinRelType) joinType: JoinRelType)
extends BatchPhysicalCorrelateBase( extends BatchPhysicalCorrelateBase(
...@@ -47,14 +46,12 @@ class BatchPhysicalCorrelate( ...@@ -47,14 +46,12 @@ class BatchPhysicalCorrelate(
inputRel, inputRel,
scan, scan,
condition, condition,
projectProgram,
outputRowType, outputRowType,
joinType) { joinType) {
def copy( def copy(
traitSet: RelTraitSet, traitSet: RelTraitSet,
child: RelNode, child: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode = { outputType: RelDataType): RelNode = {
new BatchPhysicalCorrelate( new BatchPhysicalCorrelate(
cluster, cluster,
...@@ -62,7 +59,6 @@ class BatchPhysicalCorrelate( ...@@ -62,7 +59,6 @@ class BatchPhysicalCorrelate(
child, child,
scan, scan,
condition, condition,
projectProgram,
outputType, outputType,
joinType) joinType)
} }
...@@ -70,7 +66,6 @@ class BatchPhysicalCorrelate( ...@@ -70,7 +66,6 @@ class BatchPhysicalCorrelate(
override def translateToExecNode(): ExecNode[_] = { override def translateToExecNode(): ExecNode[_] = {
new BatchExecCorrelate( new BatchExecCorrelate(
JoinTypeUtil.getFlinkJoinType(joinType), JoinTypeUtil.getFlinkJoinType(joinType),
projectProgram.orNull,
scan.getCall.asInstanceOf[RexCall], scan.getCall.asInstanceOf[RexCall],
condition.orNull, condition.orNull,
ExecEdge.DEFAULT, ExecEdge.DEFAULT,
......
...@@ -26,8 +26,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} ...@@ -26,8 +26,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{Correlate, JoinRelType} import org.apache.calcite.rel.core.{Correlate, JoinRelType}
import org.apache.calcite.rel.{RelCollationTraitDef, RelDistribution, RelFieldCollation, RelNode, RelWriter, SingleRel} import org.apache.calcite.rel.{RelCollationTraitDef, RelDistribution, RelFieldCollation, RelNode, RelWriter, SingleRel}
import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode, RexProgram} import org.apache.calcite.rex.{RexCall, RexNode}
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.util.mapping.{Mapping, MappingType, Mappings} import org.apache.calcite.util.mapping.{Mapping, MappingType, Mappings}
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
...@@ -41,7 +40,6 @@ abstract class BatchPhysicalCorrelateBase( ...@@ -41,7 +40,6 @@ abstract class BatchPhysicalCorrelateBase(
inputRel: RelNode, inputRel: RelNode,
scan: FlinkLogicalTableFunctionScan, scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode], condition: Option[RexNode],
projectProgram: Option[RexProgram],
outputRowType: RelDataType, outputRowType: RelDataType,
joinType: JoinRelType) joinType: JoinRelType)
extends SingleRel(cluster, traitSet, inputRel) extends SingleRel(cluster, traitSet, inputRel)
...@@ -52,7 +50,7 @@ abstract class BatchPhysicalCorrelateBase( ...@@ -52,7 +50,7 @@ abstract class BatchPhysicalCorrelateBase(
override def deriveRowType(): RelDataType = outputRowType override def deriveRowType(): RelDataType = outputRowType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
copy(traitSet, inputs.get(0), projectProgram, outputRowType) copy(traitSet, inputs.get(0), outputRowType)
} }
/** /**
...@@ -61,7 +59,6 @@ abstract class BatchPhysicalCorrelateBase( ...@@ -61,7 +59,6 @@ abstract class BatchPhysicalCorrelateBase(
def copy( def copy(
traitSet: RelTraitSet, traitSet: RelTraitSet,
child: RelNode, child: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode outputType: RelDataType): RelNode
override def explainTerms(pw: RelWriter): RelWriter = { override def explainTerms(pw: RelWriter): RelWriter = {
...@@ -85,30 +82,11 @@ abstract class BatchPhysicalCorrelateBase( ...@@ -85,30 +82,11 @@ abstract class BatchPhysicalCorrelateBase(
def getOutputInputMapping: Mapping = { def getOutputInputMapping: Mapping = {
val inputFieldCnt = getInput.getRowType.getFieldCount val inputFieldCnt = getInput.getRowType.getFieldCount
projectProgram match { val mapping = Mappings.create(MappingType.FUNCTION, inputFieldCnt, inputFieldCnt)
case Some(program) => (0 until inputFieldCnt).foreach {
val projects = program.getProjectList.map(program.expandLocalRef) index => mapping.set(index, index)
val mapping = Mappings.create(MappingType.INVERSE_FUNCTION, inputFieldCnt, projects.size)
projects.zipWithIndex.foreach {
case (project, index) =>
project match {
case inputRef: RexInputRef => mapping.set(inputRef.getIndex, index)
case call: RexCall if call.getKind == SqlKind.AS =>
call.getOperands.head match {
case inputRef: RexInputRef => mapping.set(inputRef.getIndex, index)
case _ => // ignore
}
case _ => // ignore
}
}
mapping.inverse()
case _ =>
val mapping = Mappings.create(MappingType.FUNCTION, inputFieldCnt, inputFieldCnt)
(0 until inputFieldCnt).foreach {
index => mapping.set(index, index)
}
mapping
} }
mapping
} }
val mapping = getOutputInputMapping val mapping = getOutputInputMapping
......
...@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} ...@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{Correlate, JoinRelType} import org.apache.calcite.rel.core.{Correlate, JoinRelType}
import org.apache.calcite.rex.{RexCall, RexNode, RexProgram} import org.apache.calcite.rex.{RexCall, RexNode}
/** /**
* Batch physical RelNode for [[Correlate]] (Python user defined table function). * Batch physical RelNode for [[Correlate]] (Python user defined table function).
...@@ -38,7 +38,6 @@ class BatchPhysicalPythonCorrelate( ...@@ -38,7 +38,6 @@ class BatchPhysicalPythonCorrelate(
inputRel: RelNode, inputRel: RelNode,
scan: FlinkLogicalTableFunctionScan, scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode], condition: Option[RexNode],
projectProgram: Option[RexProgram],
outputRowType: RelDataType, outputRowType: RelDataType,
joinType: JoinRelType) joinType: JoinRelType)
extends BatchPhysicalCorrelateBase( extends BatchPhysicalCorrelateBase(
...@@ -47,7 +46,6 @@ class BatchPhysicalPythonCorrelate( ...@@ -47,7 +46,6 @@ class BatchPhysicalPythonCorrelate(
inputRel, inputRel,
scan, scan,
condition, condition,
projectProgram,
outputRowType, outputRowType,
joinType) joinType)
with CommonPythonCorrelate { with CommonPythonCorrelate {
...@@ -55,7 +53,6 @@ class BatchPhysicalPythonCorrelate( ...@@ -55,7 +53,6 @@ class BatchPhysicalPythonCorrelate(
def copy( def copy(
traitSet: RelTraitSet, traitSet: RelTraitSet,
child: RelNode, child: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode = { outputType: RelDataType): RelNode = {
new BatchPhysicalPythonCorrelate( new BatchPhysicalPythonCorrelate(
cluster, cluster,
...@@ -63,7 +60,6 @@ class BatchPhysicalPythonCorrelate( ...@@ -63,7 +60,6 @@ class BatchPhysicalPythonCorrelate(
child, child,
scan, scan,
condition, condition,
projectProgram,
outputType, outputType,
joinType) joinType)
} }
......
...@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} ...@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinRelType import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rex.{RexCall, RexNode, RexProgram} import org.apache.calcite.rex.{RexCall, RexNode}
/** /**
* Flink RelNode which matches along with join a Java/Scala user defined table function. * Flink RelNode which matches along with join a Java/Scala user defined table function.
...@@ -36,7 +36,6 @@ class StreamPhysicalCorrelate( ...@@ -36,7 +36,6 @@ class StreamPhysicalCorrelate(
cluster: RelOptCluster, cluster: RelOptCluster,
traitSet: RelTraitSet, traitSet: RelTraitSet,
inputRel: RelNode, inputRel: RelNode,
projectProgram: Option[RexProgram],
scan: FlinkLogicalTableFunctionScan, scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode], condition: Option[RexNode],
outputRowType: RelDataType, outputRowType: RelDataType,
...@@ -45,7 +44,6 @@ class StreamPhysicalCorrelate( ...@@ -45,7 +44,6 @@ class StreamPhysicalCorrelate(
cluster, cluster,
traitSet, traitSet,
inputRel, inputRel,
projectProgram,
scan, scan,
condition, condition,
outputRowType, outputRowType,
...@@ -54,13 +52,11 @@ class StreamPhysicalCorrelate( ...@@ -54,13 +52,11 @@ class StreamPhysicalCorrelate(
def copy( def copy(
traitSet: RelTraitSet, traitSet: RelTraitSet,
newChild: RelNode, newChild: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode = { outputType: RelDataType): RelNode = {
new StreamPhysicalCorrelate( new StreamPhysicalCorrelate(
cluster, cluster,
traitSet, traitSet,
newChild, newChild,
projectProgram,
scan, scan,
condition, condition,
outputType, outputType,
...@@ -70,7 +66,6 @@ class StreamPhysicalCorrelate( ...@@ -70,7 +66,6 @@ class StreamPhysicalCorrelate(
override def translateToExecNode(): ExecNode[_] = { override def translateToExecNode(): ExecNode[_] = {
new StreamExecCorrelate( new StreamExecCorrelate(
JoinTypeUtil.getFlinkJoinType(joinType), JoinTypeUtil.getFlinkJoinType(joinType),
projectProgram.orNull,
scan.getCall.asInstanceOf[RexCall], scan.getCall.asInstanceOf[RexCall],
condition.orNull, condition.orNull,
ExecEdge.DEFAULT, ExecEdge.DEFAULT,
......
...@@ -24,7 +24,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} ...@@ -24,7 +24,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinRelType import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.rex.{RexCall, RexNode, RexProgram} import org.apache.calcite.rex.{RexCall, RexNode}
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
...@@ -35,7 +35,6 @@ abstract class StreamPhysicalCorrelateBase( ...@@ -35,7 +35,6 @@ abstract class StreamPhysicalCorrelateBase(
cluster: RelOptCluster, cluster: RelOptCluster,
traitSet: RelTraitSet, traitSet: RelTraitSet,
inputRel: RelNode, inputRel: RelNode,
val projectProgram: Option[RexProgram],
scan: FlinkLogicalTableFunctionScan, scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode], condition: Option[RexNode],
outputRowType: RelDataType, outputRowType: RelDataType,
...@@ -50,7 +49,7 @@ abstract class StreamPhysicalCorrelateBase( ...@@ -50,7 +49,7 @@ abstract class StreamPhysicalCorrelateBase(
override def deriveRowType(): RelDataType = outputRowType override def deriveRowType(): RelDataType = outputRowType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
copy(traitSet, inputs.get(0), projectProgram, outputRowType) copy(traitSet, inputs.get(0), outputRowType)
} }
/** /**
...@@ -59,7 +58,6 @@ abstract class StreamPhysicalCorrelateBase( ...@@ -59,7 +58,6 @@ abstract class StreamPhysicalCorrelateBase(
def copy( def copy(
traitSet: RelTraitSet, traitSet: RelTraitSet,
newChild: RelNode, newChild: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode outputType: RelDataType): RelNode
override def explainTerms(pw: RelWriter): RelWriter = { override def explainTerms(pw: RelWriter): RelWriter = {
......
...@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} ...@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinRelType import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rex.{RexCall, RexNode, RexProgram} import org.apache.calcite.rex.{RexCall, RexNode}
/** /**
* Flink RelNode which matches along with join a python user defined table function. * Flink RelNode which matches along with join a python user defined table function.
...@@ -36,7 +36,6 @@ class StreamPhysicalPythonCorrelate( ...@@ -36,7 +36,6 @@ class StreamPhysicalPythonCorrelate(
cluster: RelOptCluster, cluster: RelOptCluster,
traitSet: RelTraitSet, traitSet: RelTraitSet,
inputRel: RelNode, inputRel: RelNode,
projectProgram: Option[RexProgram],
scan: FlinkLogicalTableFunctionScan, scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode], condition: Option[RexNode],
outputRowType: RelDataType, outputRowType: RelDataType,
...@@ -45,7 +44,6 @@ class StreamPhysicalPythonCorrelate( ...@@ -45,7 +44,6 @@ class StreamPhysicalPythonCorrelate(
cluster, cluster,
traitSet, traitSet,
inputRel, inputRel,
projectProgram,
scan, scan,
condition, condition,
outputRowType, outputRowType,
...@@ -55,13 +53,11 @@ class StreamPhysicalPythonCorrelate( ...@@ -55,13 +53,11 @@ class StreamPhysicalPythonCorrelate(
def copy( def copy(
traitSet: RelTraitSet, traitSet: RelTraitSet,
newChild: RelNode, newChild: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode = { outputType: RelDataType): RelNode = {
new StreamPhysicalPythonCorrelate( new StreamPhysicalPythonCorrelate(
cluster, cluster,
traitSet, traitSet,
newChild, newChild,
projectProgram,
scan, scan,
condition, condition,
outputType, outputType,
......
...@@ -72,7 +72,6 @@ class BatchPhysicalConstantTableFunctionScanRule ...@@ -72,7 +72,6 @@ class BatchPhysicalConstantTableFunctionScanRule
values, values,
scan, scan,
None, None,
None,
scan.getRowType, scan.getRowType,
JoinRelType.INNER) JoinRelType.INNER)
call.transformTo(correlate) call.transformTo(correlate)
......
...@@ -77,7 +77,6 @@ class BatchPhysicalCorrelateRule extends ConverterRule( ...@@ -77,7 +77,6 @@ class BatchPhysicalCorrelateRule extends ConverterRule(
convInput, convInput,
scan, scan,
condition, condition,
None,
rel.getRowType, rel.getRowType,
join.getJoinType) join.getJoinType)
} }
......
...@@ -70,7 +70,6 @@ class StreamPhysicalConstantTableFunctionScanRule ...@@ -70,7 +70,6 @@ class StreamPhysicalConstantTableFunctionScanRule
cluster, cluster,
traitSet, traitSet,
values, values,
None,
scan, scan,
None, None,
scan.getRowType, scan.getRowType,
......
...@@ -92,7 +92,6 @@ class StreamPhysicalCorrelateRule ...@@ -92,7 +92,6 @@ class StreamPhysicalCorrelateRule
rel.getCluster, rel.getCluster,
traitSet, traitSet,
convInput, convInput,
None,
scan, scan,
condition, condition,
rel.getRowType, rel.getRowType,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册