未验证 提交 110a0444 编写于 作者: J Jingsong Lee 提交者: GitHub

[FLINK-20745][table] Clean useless codes: Never push calcProgram to correlate

This closes #14474
上级 1aa9074e
......@@ -27,7 +27,6 @@ import org.apache.flink.table.types.logical.RowType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import javax.annotation.Nullable;
......@@ -38,7 +37,6 @@ public class BatchExecCorrelate extends CommonExecCorrelate implements BatchExec
public BatchExecCorrelate(
FlinkJoinType joinType,
@Nullable RexProgram project,
RexCall invocation,
@Nullable RexNode condition,
ExecEdge inputEdge,
......@@ -46,7 +44,6 @@ public class BatchExecCorrelate extends CommonExecCorrelate implements BatchExec
String description) {
super(
joinType,
project,
invocation,
condition,
TableStreamOperator.class,
......
......@@ -32,7 +32,6 @@ import org.apache.flink.table.types.logical.RowType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import javax.annotation.Nullable;
......@@ -44,8 +43,6 @@ import java.util.Optional;
*/
public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> {
private final FlinkJoinType joinType;
@Nullable
private final RexProgram project;
private final RexCall invocation;
@Nullable
private final RexNode condition;
......@@ -54,7 +51,6 @@ public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> {
public CommonExecCorrelate(
FlinkJoinType joinType,
@Nullable RexProgram project,
RexCall invocation,
@Nullable RexNode condition,
Class<?> operatorBaseClass,
......@@ -64,7 +60,6 @@ public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> {
String description) {
super(Collections.singletonList(inputEdge), outputType, description);
this.joinType = joinType;
this.project = project;
this.invocation = invocation;
this.condition = condition;
this.operatorBaseClass = operatorBaseClass;
......@@ -83,7 +78,6 @@ public abstract class CommonExecCorrelate extends ExecNodeBase<RowData> {
ctx,
inputTransform,
(RowType) inputNode.getOutputType(),
JavaScalaConversionUtil.toScala(Optional.ofNullable(project)),
invocation,
JavaScalaConversionUtil.toScala(Optional.ofNullable(condition)),
(RowType) getOutputType(),
......
......@@ -28,7 +28,6 @@ import org.apache.flink.table.types.logical.RowType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import javax.annotation.Nullable;
......@@ -39,7 +38,6 @@ public class StreamExecCorrelate extends CommonExecCorrelate implements StreamEx
public StreamExecCorrelate(
FlinkJoinType joinType,
@Nullable RexProgram project,
RexCall invocation,
@Nullable RexNode condition,
ExecEdge inputEdge,
......@@ -47,7 +45,6 @@ public class StreamExecCorrelate extends CommonExecCorrelate implements StreamEx
String description) {
super(
joinType,
project,
invocation,
condition,
AbstractProcessStreamOperator.class,
......
......@@ -118,7 +118,6 @@ public class BatchPhysicalPythonCorrelateRule extends ConverterRule {
convInput,
scan,
condition,
null,
correlate.getRowType(),
correlate.getJoinType());
}
......
......@@ -124,7 +124,6 @@ public class StreamPhysicalPythonCorrelateRule extends ConverterRule {
correlate.getCluster(),
traitSet,
convInput,
null,
scan,
condition,
correlate.getRowType(),
......
......@@ -21,8 +21,8 @@ package org.apache.flink.table.planner.codegen
import org.apache.flink.api.common.functions.Function
import org.apache.flink.api.dag.Transformation
import org.apache.flink.table.api.{TableConfig, TableException, ValidationException}
import org.apache.flink.table.data.RowData
import org.apache.flink.table.data.utils.JoinedRowData
import org.apache.flink.table.data.{GenericRowData, RowData}
import org.apache.flink.table.functions.FunctionKind
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGenUtils._
......@@ -37,8 +37,6 @@ import org.apache.flink.table.types.logical.RowType
import org.apache.calcite.rex._
import scala.collection.JavaConversions._
object CorrelateCodeGenerator {
def generateCorrelateTransformation(
......@@ -46,7 +44,6 @@ object CorrelateCodeGenerator {
operatorCtx: CodeGeneratorContext,
inputTransformation: Transformation[RowData],
inputType: RowType,
projectProgram: Option[RexProgram],
invocation: RexCall,
condition: Option[RexNode],
outputType: RowType,
......@@ -68,19 +65,6 @@ object CorrelateCodeGenerator {
s"Currently, only table functions can be used in a correlate operation.")
}
val swallowInputOnly = if (projectProgram.isDefined) {
val program = projectProgram.get
val selects = program.getProjectList.map(_.getIndex)
val inputFieldCnt = program.getInputRowType.getFieldCount
val swallowInputOnly = selects.head > inputFieldCnt &&
(inputFieldCnt - outputType.getFieldCount == inputType.getFieldCount)
// partial output or output right only
swallowInputOnly
} else {
// completely output left input + right
false
}
// adjust indicies of InputRefs to adhere to schema expected by generator
val changeInputRefIndexShuttle = new RexShuttle {
override def visitInputRef(inputRef: RexInputRef): RexNode = {
......@@ -92,8 +76,6 @@ object CorrelateCodeGenerator {
operatorCtx,
config,
inputType,
projectProgram,
swallowInputOnly,
condition.map(_.accept(changeInputRefIndexShuttle)),
outputType,
joinType,
......@@ -117,8 +99,6 @@ object CorrelateCodeGenerator {
ctx: CodeGeneratorContext,
config: TableConfig,
inputType: RowType,
projectProgram: Option[RexProgram],
swallowInputOnly: Boolean = false,
condition: Option[RexNode],
returnType: RowType,
joinType: FlinkJoinType,
......@@ -136,8 +116,6 @@ object CorrelateCodeGenerator {
ctx,
config,
inputType,
projectProgram,
swallowInputOnly,
functionResultType,
returnType,
condition,
......@@ -167,78 +145,27 @@ object CorrelateCodeGenerator {
// 3. left join
if (joinType == FlinkJoinType.LEFT) {
if (swallowInputOnly) {
// and the returned row table function is empty, collect a null
val nullRowTerm = CodeGenUtils.newName("nullRow")
ctx.addReusableOutputRecord(functionResultType, classOf[GenericRowData], nullRowTerm)
ctx.addReusableNullRow(nullRowTerm, functionResultType.getFieldCount)
val header = if (retainHeader) {
s"$nullRowTerm.setRowKind(${exprGenerator.input1Term}.getRowKind());"
} else {
""
}
body +=
s"""
|boolean hasOutput = $correlateCollectorTerm.isCollected();
|if (!hasOutput) {
| $header
| $correlateCollectorTerm.outputResult($nullRowTerm);
|}
|""".stripMargin
} else if (projectProgram.isDefined) {
// output partial fields of left and right
val outputTerm = CodeGenUtils.newName("projectOut")
ctx.addReusableOutputRecord(returnType, classOf[GenericRowData], outputTerm)
val header = if (retainHeader) {
s"$outputTerm.setRowKind(${CodeGenUtils.DEFAULT_INPUT1_TERM}.getRowKind());"
} else {
""
}
val projectionExpression = generateProjectResultExpr(
ctx,
config,
inputType,
functionResultType,
udtfAlwaysNull = true,
returnType,
outputTerm,
projectProgram.get)
body +=
s"""
|boolean hasOutput = $correlateCollectorTerm.isCollected();
|if (!hasOutput) {
| ${projectionExpression.code}
| $header
| $correlateCollectorTerm.outputResult($outputTerm);
|}
|""".stripMargin
// output all fields of left and right
// in case of left outer join and the returned row of table function is empty,
// fill all fields of row with null
val joinedRowTerm = CodeGenUtils.newName("joinedRow")
val nullRowTerm = CodeGenUtils.newName("nullRow")
ctx.addReusableOutputRecord(returnType, classOf[JoinedRowData], joinedRowTerm)
ctx.addReusableNullRow(nullRowTerm, functionResultType.getFieldCount)
val header = if (retainHeader) {
s"$joinedRowTerm.setRowKind(${exprGenerator.input1Term}.getRowKind());"
} else {
// output all fields of left and right
// in case of left outer join and the returned row of table function is empty,
// fill all fields of row with null
val joinedRowTerm = CodeGenUtils.newName("joinedRow")
val nullRowTerm = CodeGenUtils.newName("nullRow")
ctx.addReusableOutputRecord(returnType, classOf[JoinedRowData], joinedRowTerm)
ctx.addReusableNullRow(nullRowTerm, functionResultType.getFieldCount)
val header = if (retainHeader) {
s"$joinedRowTerm.setRowKind(${exprGenerator.input1Term}.getRowKind());"
} else {
""
}
body +=
s"""
|boolean hasOutput = $correlateCollectorTerm.isCollected();
|if (!hasOutput) {
| $joinedRowTerm.replace(${exprGenerator.input1Term}, $nullRowTerm);
| $header
| $correlateCollectorTerm.outputResult($joinedRowTerm);
|}
|""".stripMargin
}
""
}
body +=
s"""
|boolean hasOutput = $correlateCollectorTerm.isCollected();
|if (!hasOutput) {
| $joinedRowTerm.replace(${exprGenerator.input1Term}, $nullRowTerm);
| $header
| $correlateCollectorTerm.outputResult($joinedRowTerm);
|}
|""".stripMargin
} else if (joinType != FlinkJoinType.INNER) {
throw new TableException(s"Unsupported JoinRelType: $joinType for correlate join.")
}
......@@ -248,34 +175,6 @@ object CorrelateCodeGenerator {
new CodeGenOperatorFactory(genOperator)
}
private def generateProjectResultExpr(
ctx: CodeGeneratorContext,
config: TableConfig,
input1Type: RowType,
functionResultType: RowType,
udtfAlwaysNull: Boolean,
returnType: RowType,
outputTerm: String,
program: RexProgram): GeneratedExpression = {
val projectExprGenerator = new ExprCodeGenerator(ctx, udtfAlwaysNull)
.bindInput(input1Type, CodeGenUtils.DEFAULT_INPUT1_TERM)
if (udtfAlwaysNull) {
val udtfNullRow = CodeGenUtils.newName("udtfNullRow")
ctx.addReusableNullRow(udtfNullRow, functionResultType.getFieldCount)
projectExprGenerator.bindSecondInput(
functionResultType,
udtfNullRow)
} else {
projectExprGenerator.bindSecondInput(
functionResultType)
}
val projection = program.getProjectList.map(program.expandLocalRef)
val projectionExprs = projection.map(projectExprGenerator.generateExpression)
projectExprGenerator.generateResultExpression(
projectionExprs, returnType, classOf[GenericRowData], outputTerm)
}
/**
* Generates a collector that correlates input and converted table function results. Returns a
* collector term for referencing the collector.
......@@ -284,8 +183,6 @@ object CorrelateCodeGenerator {
ctx: CodeGeneratorContext,
config: TableConfig,
inputType: RowType,
projectProgram: Option[RexProgram],
swallowInputOnly: Boolean,
functionResultType: RowType,
resultType: RowType,
condition: Option[RexNode],
......@@ -298,45 +195,7 @@ object CorrelateCodeGenerator {
val collectorCtx = CodeGeneratorContext(config)
val body = if (projectProgram.isDefined) {
// partial output
if (swallowInputOnly) {
// output right only
val header = if (retainHeader) {
s"$udtfInputTerm.setRowKind($inputTerm.getRowKind());"
} else {
""
}
s"""
|$header
|outputResult($udtfInputTerm);
""".stripMargin
} else {
val outputTerm = CodeGenUtils.newName("projectOut")
collectorCtx.addReusableOutputRecord(resultType, classOf[GenericRowData], outputTerm)
val header = if (retainHeader) {
s"$outputTerm.setRowKind($inputTerm.getRowKind());"
} else {
""
}
val projectionExpression = generateProjectResultExpr(
collectorCtx,
config,
inputType,
functionResultType,
udtfAlwaysNull = false,
resultType,
outputTerm,
projectProgram.get)
s"""
|$header
|${projectionExpression.code}
|outputResult(${projectionExpression.resultTerm});
""".stripMargin
}
} else {
val body = {
// completely output left input + right
val joinedRowTerm = CodeGenUtils.newName("joinedRow")
collectorCtx.addReusableOutputRecord(resultType, classOf[JoinedRowData], joinedRowTerm)
......
......@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{Correlate, JoinRelType}
import org.apache.calcite.rex.{RexCall, RexNode, RexProgram}
import org.apache.calcite.rex.{RexCall, RexNode}
/**
* Batch physical RelNode for [[Correlate]] (Java/Scala user defined table function).
......@@ -38,7 +38,6 @@ class BatchPhysicalCorrelate(
inputRel: RelNode,
scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode],
projectProgram: Option[RexProgram],
outputRowType: RelDataType,
joinType: JoinRelType)
extends BatchPhysicalCorrelateBase(
......@@ -47,14 +46,12 @@ class BatchPhysicalCorrelate(
inputRel,
scan,
condition,
projectProgram,
outputRowType,
joinType) {
def copy(
traitSet: RelTraitSet,
child: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode = {
new BatchPhysicalCorrelate(
cluster,
......@@ -62,7 +59,6 @@ class BatchPhysicalCorrelate(
child,
scan,
condition,
projectProgram,
outputType,
joinType)
}
......@@ -70,7 +66,6 @@ class BatchPhysicalCorrelate(
override def translateToExecNode(): ExecNode[_] = {
new BatchExecCorrelate(
JoinTypeUtil.getFlinkJoinType(joinType),
projectProgram.orNull,
scan.getCall.asInstanceOf[RexCall],
condition.orNull,
ExecEdge.DEFAULT,
......
......@@ -26,8 +26,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{Correlate, JoinRelType}
import org.apache.calcite.rel.{RelCollationTraitDef, RelDistribution, RelFieldCollation, RelNode, RelWriter, SingleRel}
import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode, RexProgram}
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.rex.{RexCall, RexNode}
import org.apache.calcite.util.mapping.{Mapping, MappingType, Mappings}
import scala.collection.JavaConversions._
......@@ -41,7 +40,6 @@ abstract class BatchPhysicalCorrelateBase(
inputRel: RelNode,
scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode],
projectProgram: Option[RexProgram],
outputRowType: RelDataType,
joinType: JoinRelType)
extends SingleRel(cluster, traitSet, inputRel)
......@@ -52,7 +50,7 @@ abstract class BatchPhysicalCorrelateBase(
override def deriveRowType(): RelDataType = outputRowType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
copy(traitSet, inputs.get(0), projectProgram, outputRowType)
copy(traitSet, inputs.get(0), outputRowType)
}
/**
......@@ -61,7 +59,6 @@ abstract class BatchPhysicalCorrelateBase(
def copy(
traitSet: RelTraitSet,
child: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode
override def explainTerms(pw: RelWriter): RelWriter = {
......@@ -85,30 +82,11 @@ abstract class BatchPhysicalCorrelateBase(
def getOutputInputMapping: Mapping = {
val inputFieldCnt = getInput.getRowType.getFieldCount
projectProgram match {
case Some(program) =>
val projects = program.getProjectList.map(program.expandLocalRef)
val mapping = Mappings.create(MappingType.INVERSE_FUNCTION, inputFieldCnt, projects.size)
projects.zipWithIndex.foreach {
case (project, index) =>
project match {
case inputRef: RexInputRef => mapping.set(inputRef.getIndex, index)
case call: RexCall if call.getKind == SqlKind.AS =>
call.getOperands.head match {
case inputRef: RexInputRef => mapping.set(inputRef.getIndex, index)
case _ => // ignore
}
case _ => // ignore
}
}
mapping.inverse()
case _ =>
val mapping = Mappings.create(MappingType.FUNCTION, inputFieldCnt, inputFieldCnt)
(0 until inputFieldCnt).foreach {
index => mapping.set(index, index)
}
mapping
val mapping = Mappings.create(MappingType.FUNCTION, inputFieldCnt, inputFieldCnt)
(0 until inputFieldCnt).foreach {
index => mapping.set(index, index)
}
mapping
}
val mapping = getOutputInputMapping
......
......@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{Correlate, JoinRelType}
import org.apache.calcite.rex.{RexCall, RexNode, RexProgram}
import org.apache.calcite.rex.{RexCall, RexNode}
/**
* Batch physical RelNode for [[Correlate]] (Python user defined table function).
......@@ -38,7 +38,6 @@ class BatchPhysicalPythonCorrelate(
inputRel: RelNode,
scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode],
projectProgram: Option[RexProgram],
outputRowType: RelDataType,
joinType: JoinRelType)
extends BatchPhysicalCorrelateBase(
......@@ -47,7 +46,6 @@ class BatchPhysicalPythonCorrelate(
inputRel,
scan,
condition,
projectProgram,
outputRowType,
joinType)
with CommonPythonCorrelate {
......@@ -55,7 +53,6 @@ class BatchPhysicalPythonCorrelate(
def copy(
traitSet: RelTraitSet,
child: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode = {
new BatchPhysicalPythonCorrelate(
cluster,
......@@ -63,7 +60,6 @@ class BatchPhysicalPythonCorrelate(
child,
scan,
condition,
projectProgram,
outputType,
joinType)
}
......
......@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rex.{RexCall, RexNode, RexProgram}
import org.apache.calcite.rex.{RexCall, RexNode}
/**
* Flink RelNode which matches along with join a Java/Scala user defined table function.
......@@ -36,7 +36,6 @@ class StreamPhysicalCorrelate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
projectProgram: Option[RexProgram],
scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode],
outputRowType: RelDataType,
......@@ -45,7 +44,6 @@ class StreamPhysicalCorrelate(
cluster,
traitSet,
inputRel,
projectProgram,
scan,
condition,
outputRowType,
......@@ -54,13 +52,11 @@ class StreamPhysicalCorrelate(
def copy(
traitSet: RelTraitSet,
newChild: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode = {
new StreamPhysicalCorrelate(
cluster,
traitSet,
newChild,
projectProgram,
scan,
condition,
outputType,
......@@ -70,7 +66,6 @@ class StreamPhysicalCorrelate(
override def translateToExecNode(): ExecNode[_] = {
new StreamExecCorrelate(
JoinTypeUtil.getFlinkJoinType(joinType),
projectProgram.orNull,
scan.getCall.asInstanceOf[RexCall],
condition.orNull,
ExecEdge.DEFAULT,
......
......@@ -24,7 +24,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.rex.{RexCall, RexNode, RexProgram}
import org.apache.calcite.rex.{RexCall, RexNode}
import scala.collection.JavaConversions._
......@@ -35,7 +35,6 @@ abstract class StreamPhysicalCorrelateBase(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
val projectProgram: Option[RexProgram],
scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode],
outputRowType: RelDataType,
......@@ -50,7 +49,7 @@ abstract class StreamPhysicalCorrelateBase(
override def deriveRowType(): RelDataType = outputRowType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
copy(traitSet, inputs.get(0), projectProgram, outputRowType)
copy(traitSet, inputs.get(0), outputRowType)
}
/**
......@@ -59,7 +58,6 @@ abstract class StreamPhysicalCorrelateBase(
def copy(
traitSet: RelTraitSet,
newChild: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode
override def explainTerms(pw: RelWriter): RelWriter = {
......
......@@ -27,7 +27,7 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rex.{RexCall, RexNode, RexProgram}
import org.apache.calcite.rex.{RexCall, RexNode}
/**
* Flink RelNode which matches along with join a python user defined table function.
......@@ -36,7 +36,6 @@ class StreamPhysicalPythonCorrelate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
projectProgram: Option[RexProgram],
scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode],
outputRowType: RelDataType,
......@@ -45,7 +44,6 @@ class StreamPhysicalPythonCorrelate(
cluster,
traitSet,
inputRel,
projectProgram,
scan,
condition,
outputRowType,
......@@ -55,13 +53,11 @@ class StreamPhysicalPythonCorrelate(
def copy(
traitSet: RelTraitSet,
newChild: RelNode,
projectProgram: Option[RexProgram],
outputType: RelDataType): RelNode = {
new StreamPhysicalPythonCorrelate(
cluster,
traitSet,
newChild,
projectProgram,
scan,
condition,
outputType,
......
......@@ -72,7 +72,6 @@ class BatchPhysicalConstantTableFunctionScanRule
values,
scan,
None,
None,
scan.getRowType,
JoinRelType.INNER)
call.transformTo(correlate)
......
......@@ -77,7 +77,6 @@ class BatchPhysicalCorrelateRule extends ConverterRule(
convInput,
scan,
condition,
None,
rel.getRowType,
join.getJoinType)
}
......
......@@ -70,7 +70,6 @@ class StreamPhysicalConstantTableFunctionScanRule
cluster,
traitSet,
values,
None,
scan,
None,
scan.getRowType,
......
......@@ -92,7 +92,6 @@ class StreamPhysicalCorrelateRule
rel.getCluster,
traitSet,
convInput,
None,
scan,
condition,
rel.getRowType,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册