Commit 049a61a5 authored by godfreyhe

[FLINK-20738][table-planner-blink] Introduce BatchPhysicalPythonGroupAggregate, and make BatchExecPythonGroupAggregate only extend from ExecNode

This closes #14562

Parent 63875b3c
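Before the diffs: this commit applies the planner's split between optimizer-facing physical RelNodes and planner-independent exec nodes. The physical node keeps only planning state and bridges to execution via `translateToExecNode()`; only the exec node knows how to become a runtime `Transformation`. Below is a minimal, self-contained sketch of that hand-off; all types are simplified stand-ins, not Flink's real signatures.

```scala
object ExecNodeSplitSketch {
  // Exec-layer contract: the only thing that can become a runtime plan.
  trait ExecNode[T] { def translateToPlan(): T }

  // Exec node: planner-independent, constructed from plain data.
  class ExecPythonGroupAggregate(grouping: Array[Int], description: String)
    extends ExecNode[String] {
    override def translateToPlan(): String =
      s"PythonGroupAggregate(grouping=[${grouping.mkString(",")}]) // $description"
  }

  // Physical node: optimizer-facing; its single bridge to execution
  // is translateToExecNode().
  class PhysicalPythonGroupAggregate(grouping: Array[Int], description: String) {
    def translateToExecNode(): ExecNode[String] =
      new ExecPythonGroupAggregate(grouping, description)
  }

  def main(args: Array[String]): Unit = {
    val physical = new PhysicalPythonGroupAggregate(Array(0), "sketch")
    println(physical.translateToExecNode().translateToPlan())
  }
}
```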
@@ -24,7 +24,7 @@ import org.apache.flink.table.functions.python.PythonFunctionKind;
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
 import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate;
-import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupAggregate;
+import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalPythonGroupAggregate;
 import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
 import org.apache.flink.table.planner.plan.utils.AggregateUtil;
 import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil;
@@ -50,18 +50,18 @@ import scala.collection.Seq;

 /**
  * The physical rule which is responsible for converting {@link FlinkLogicalAggregate} to {@link
- * BatchExecPythonGroupAggregate}.
+ * BatchPhysicalPythonGroupAggregate}.
  */
-public class BatchExecPythonAggregateRule extends ConverterRule {
+public class BatchPhysicalPythonAggregateRule extends ConverterRule {

-    public static final RelOptRule INSTANCE = new BatchExecPythonAggregateRule();
+    public static final RelOptRule INSTANCE = new BatchPhysicalPythonAggregateRule();

-    private BatchExecPythonAggregateRule() {
+    private BatchPhysicalPythonAggregateRule() {
         super(
                 FlinkLogicalAggregate.class,
                 FlinkConventions.LOGICAL(),
                 FlinkConventions.BATCH_PHYSICAL(),
-                "BatchExecPythonAggregateRule");
+                "BatchPhysicalPythonAggregateRule");
     }

     @Override
@@ -124,7 +124,7 @@ public class BatchExecPythonAggregateRule extends ConverterRule {
         }

         RelNode convInput = RelOptRule.convert(input, requiredTraitSet);
-        return new BatchExecPythonGroupAggregate(
+        return new BatchPhysicalPythonGroupAggregate(
                 relNode.getCluster(),
                 traitSet,
                 convInput,
......
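The renamed rule above is a Calcite `ConverterRule`: it fires only for aggregates that call Python UDAFs (hence the `PythonFunctionKind` import) and its `convert()` builds the batch-physical node over a converted input. A toy, runnable analogue of that contract, using hypothetical stub types rather than Calcite's API:

```scala
object ConverterRuleSketch {
  // Stub relational nodes; stand-ins for FlinkLogicalAggregate and
  // BatchPhysicalPythonGroupAggregate.
  sealed trait Rel
  final case class LogicalAgg(groupKeys: Seq[Int], usesPythonUdaf: Boolean) extends Rel
  final case class PhysicalPythonAgg(groupKeys: Seq[Int]) extends Rel

  // Like the real rule's INSTANCE: a stateless singleton that only fires
  // when the aggregate actually uses a Python UDAF.
  val INSTANCE: PartialFunction[Rel, Rel] = {
    case LogicalAgg(keys, true) => PhysicalPythonAgg(keys)
  }

  def main(args: Array[String]): Unit =
    println(INSTANCE(LogicalAgg(Seq(0, 1), usesPythonUdaf = true)))
}
```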
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.plan.nodes.exec.batch
+
+import org.apache.flink.api.dag.Transformation
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.core.memory.ManagedMemoryUseCase
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator
+import org.apache.flink.streaming.api.transformations.OneInputTransformation
+import org.apache.flink.table.data.RowData
+import org.apache.flink.table.functions.python.PythonFunctionInfo
+import org.apache.flink.table.planner.delegation.PlannerBase
+import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecPythonGroupAggregate.ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME
+import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate
+import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode, ExecNodeBase}
+import org.apache.flink.table.planner.utils.Logging
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
+import org.apache.flink.table.types.logical.RowType
+
+import org.apache.calcite.rel.core.AggregateCall
+
+import java.util.Collections
+
+/**
+ * Batch [[ExecNode]] for aggregate (Python user defined aggregate function).
+ *
+ * <p>Note: This class can't be ported to Java, because a Java class can't
+ * extend a Scala interface with default implementation.
+ * FLINK-20751 will port this class to Java.
+ */
+class BatchExecPythonGroupAggregate(
+    grouping: Array[Int],
+    auxGrouping: Array[Int],
+    aggCalls: Seq[AggregateCall],
+    inputEdge: ExecEdge,
+    outputType: RowType,
+    description: String)
+  extends ExecNodeBase[RowData](Collections.singletonList(inputEdge), outputType, description)
+  with BatchExecNode[RowData]
+  with CommonExecPythonAggregate
+  with Logging {
+
+  override protected def translateToPlanInternal(
+      planner: PlannerBase): Transformation[RowData] = {
+    val inputNode = getInputNodes.get(0).asInstanceOf[ExecNode[RowData]]
+    val inputTransform = inputNode.translateToPlan(planner)
+    val ret = createPythonOneInputTransformation(
+      inputTransform,
+      inputNode.getOutputType.asInstanceOf[RowType],
+      outputType,
+      getConfig(planner.getExecEnv, planner.getTableConfig))
+    if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
+      ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON)
+    }
+    ret
+  }
+
+  private[this] def createPythonOneInputTransformation(
+      inputTransform: Transformation[RowData],
+      inputRowType: RowType,
+      outputRowType: RowType,
+      config: Configuration): OneInputTransformation[RowData, RowData] = {
+    val (pythonUdafInputOffsets, pythonFunctionInfos) =
+      extractPythonAggregateFunctionInfosFromAggregateCall(aggCalls)
+    val pythonOperator = getPythonAggregateFunctionOperator(
+      config,
+      inputRowType,
+      outputRowType,
+      pythonUdafInputOffsets,
+      pythonFunctionInfos)
+    new OneInputTransformation(
+      inputTransform,
+      "BatchExecPythonGroupAggregate",
+      pythonOperator,
+      InternalTypeInfo.of(outputRowType),
+      inputTransform.getParallelism)
+  }
+
+  private[this] def getPythonAggregateFunctionOperator(
+      config: Configuration,
+      inputRowType: RowType,
+      outputRowType: RowType,
+      udafInputOffsets: Array[Int],
+      pythonFunctionInfos: Array[PythonFunctionInfo]): OneInputStreamOperator[RowData, RowData] = {
+    val clazz = loadClass(ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME)
+    val ctor = clazz.getConstructor(
+      classOf[Configuration],
+      classOf[Array[PythonFunctionInfo]],
+      classOf[RowType],
+      classOf[RowType],
+      classOf[Array[Int]],
+      classOf[Array[Int]],
+      classOf[Array[Int]])
+    ctor.newInstance(
+      config,
+      pythonFunctionInfos,
+      inputRowType,
+      outputRowType,
+      grouping,
+      grouping ++ auxGrouping,
+      udafInputOffsets).asInstanceOf[OneInputStreamOperator[RowData, RowData]]
+  }
+}
+
+object BatchExecPythonGroupAggregate {
+  val ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME: String =
+    "org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch." +
+      "BatchArrowPythonGroupAggregateFunctionOperator"
+}
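The scaladoc note above (FLINK-20751) is about Scala trait encoding: `CommonExecPythonAggregate` is a Scala trait with concrete members, and whether such a member surfaces to Java as a true default method depends on the Scala version and on what else the trait contains. A toy illustration with hypothetical names, not the real trait:

```scala
// A trait with a concrete ("default") member, loosely modeled on
// CommonExecPythonAggregate.
trait CommonPythonAggregateSketch {
  def loadClass(name: String): Class[_] = Class.forName(name)
}

// From Scala, the implementation is inherited as usual.
class ScalaNodeSketch extends CommonPythonAggregateSketch

// The Java equivalent,
//   class JavaNodeSketch implements CommonPythonAggregateSketch { }
// may not inherit loadClass's body, since Scala can emit it as a synthetic
// static helper rather than a Java default method; the Java subclass would
// have to re-implement or delegate by hand, which is why the port to Java
// is deferred to FLINK-20751.
```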
@@ -18,24 +18,13 @@
 package org.apache.flink.table.planner.plan.nodes.physical.batch

-import org.apache.flink.api.dag.Transformation
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.core.memory.ManagedMemoryUseCase
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator
-import org.apache.flink.streaming.api.transformations.OneInputTransformation
-import org.apache.flink.table.data.RowData
 import org.apache.flink.table.functions.UserDefinedFunction
-import org.apache.flink.table.functions.python.PythonFunctionInfo
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory
-import org.apache.flink.table.planner.delegation.BatchPlanner
 import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef}
-import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate
-import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, LegacyBatchExecNode}
-import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupAggregate.ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME
+import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
+import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecPythonGroupAggregate
 import org.apache.flink.table.planner.plan.rules.physical.batch.BatchExecJoinRuleBase
 import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, RelExplainUtil}
-import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
-import org.apache.flink.table.types.logical.RowType

 import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet}
 import org.apache.calcite.rel.RelDistribution.Type.{HASH_DISTRIBUTED, SINGLETON}
@@ -51,7 +40,7 @@ import scala.collection.JavaConversions._
 /**
  * Batch physical RelNode for aggregate (Python user defined aggregate function).
  */
-class BatchExecPythonGroupAggregate(
+class BatchPhysicalPythonGroupAggregate(
     cluster: RelOptCluster,
     traitSet: RelTraitSet,
     inputRel: RelNode,
@@ -71,9 +60,7 @@ class BatchExecPythonGroupAggregate(
     auxGrouping,
     aggCalls.zip(aggFunctions),
     isMerge = false,
-    isFinal = true)
-  with LegacyBatchExecNode[RowData]
-  with CommonExecPythonAggregate {
+    isFinal = true) {

   override def explainTerms(pw: RelWriter): RelWriter =
     super.explainTerms(pw)
@@ -149,7 +136,7 @@ class BatchExecPythonGroupAggregate(
   }

   override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
-    new BatchExecPythonGroupAggregate(
+    new BatchPhysicalPythonGroupAggregate(
       cluster,
       traitSet,
       inputs.get(0),
@@ -162,84 +149,15 @@
       aggFunctions)
   }

-  //~ ExecNode methods -----------------------------------------------------------
-
-  override def getInputEdges: util.List[ExecEdge] = List(
-    ExecEdge.builder()
-      .damBehavior(ExecEdge.DamBehavior.END_INPUT)
-      .build())
-
-  override protected def translateToPlanInternal(
-      planner: BatchPlanner): Transformation[RowData] = {
-    val input = getInputNodes.get(0).translateToPlan(planner)
-      .asInstanceOf[Transformation[RowData]]
-    val outputType = FlinkTypeFactory.toLogicalRowType(getRowType)
-    val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
-    val ret = createPythonOneInputTransformation(
-      input,
-      inputType,
-      outputType,
-      getConfig(planner.getExecEnv, planner.getTableConfig))
-    if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
-      ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON)
-    }
-    ret
-  }
-
-  private[this] def createPythonOneInputTransformation(
-      inputTransform: Transformation[RowData],
-      inputRowType: RowType,
-      outputRowType: RowType,
-      config: Configuration): OneInputTransformation[RowData, RowData] = {
-    val (pythonUdafInputOffsets, pythonFunctionInfos) =
-      extractPythonAggregateFunctionInfosFromAggregateCall(aggCalls)
-    val pythonOperator = getPythonAggregateFunctionOperator(
-      config,
-      inputRowType,
-      outputRowType,
-      pythonUdafInputOffsets,
-      pythonFunctionInfos)
-    new OneInputTransformation(
-      inputTransform,
-      "BatchExecPythonGroupAggregate",
-      pythonOperator,
-      InternalTypeInfo.of(outputRowType),
-      inputTransform.getParallelism)
-  }
-
-  private[this] def getPythonAggregateFunctionOperator(
-      config: Configuration,
-      inputRowType: RowType,
-      outputRowType: RowType,
-      udafInputOffsets: Array[Int],
-      pythonFunctionInfos: Array[PythonFunctionInfo]): OneInputStreamOperator[RowData, RowData] = {
-    val clazz = loadClass(ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME)
-    val ctor = clazz.getConstructor(
-      classOf[Configuration],
-      classOf[Array[PythonFunctionInfo]],
-      classOf[RowType],
-      classOf[RowType],
-      classOf[Array[Int]],
-      classOf[Array[Int]],
-      classOf[Array[Int]])
-    ctor.newInstance(
-      config,
-      pythonFunctionInfos,
-      inputRowType,
-      outputRowType,
-      grouping,
-      grouping ++ auxGrouping,
-      udafInputOffsets).asInstanceOf[OneInputStreamOperator[RowData, RowData]]
-  }
-}
+  override def translateToExecNode(): ExecNode[_] = {
+    new BatchExecPythonGroupAggregate(
+      grouping,
+      auxGrouping,
+      aggCalls,
+      ExecEdge.builder().damBehavior(ExecEdge.DamBehavior.END_INPUT).build(),
+      FlinkTypeFactory.toLogicalRowType(getRowType),
+      getRelDetailedDescription
+    )
+  }
+}

-object BatchExecPythonGroupAggregate {
-  val ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME: String =
-    "org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch." +
-      "BatchArrowPythonGroupAggregateFunctionOperator"
-}
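One design point survives the refactoring unchanged: both the removed `translateToPlanInternal` above and the new exec node construct the Arrow operator reflectively from `ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME` instead of importing it, presumably to keep the planner free of a compile-time dependency on the module that ships the Python operators. A self-contained sketch of that mechanism, with a hypothetical stand-in operator:

```scala
// Hypothetical stand-in for BatchArrowPythonGroupAggregateFunctionOperator.
class FakeArrowAggOperatorSketch(val config: String, val groupings: Array[Int])

object ReflectiveOperatorSketch {
  def main(args: Array[String]): Unit = {
    // Load by fully-qualified name, resolve the constructor by parameter
    // types, then instantiate; this mirrors getPythonAggregateFunctionOperator.
    val clazz = Class.forName("FakeArrowAggOperatorSketch")
    val ctor = clazz.getConstructor(classOf[String], classOf[Array[Int]])
    val op = ctor
      .newInstance("python.config", Array(0, 2))
      .asInstanceOf[FakeArrowAggOperatorSketch]
    println(op.groupings.mkString(",")) // prints: 0,2
  }
}
```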
@@ -419,7 +419,7 @@ object FlinkBatchRuleSets {
     RemoveRedundantLocalSortAggRule.WITHOUT_SORT,
     RemoveRedundantLocalSortAggRule.WITH_SORT,
     RemoveRedundantLocalHashAggRule.INSTANCE,
-    BatchExecPythonAggregateRule.INSTANCE,
+    BatchPhysicalPythonAggregateRule.INSTANCE,
     // over agg
     BatchExecOverAggregateRule.INSTANCE,
     // window agg
......