提交 7fd16039 编写于 作者: T Timo Walther

[FLINK-20522][table-planner-blink] Migrate InternalAggregateFunction to BuiltInAggregateFunction

上级 bd3bb365
......@@ -23,9 +23,12 @@ import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.data.GenericMapData;
import org.apache.flink.table.data.MapData;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
......@@ -34,7 +37,7 @@ import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataTyp
/** Built-in COLLECT aggregate function. */
@Internal
public final class CollectAggFunction<T>
extends InternalAggregateFunction<MapData, CollectAggFunction.CollectAccumulator<T>> {
extends BuiltInAggregateFunction<MapData, CollectAggFunction.CollectAccumulator<T>> {
private static final long serialVersionUID = -5860934997657147836L;
......@@ -49,8 +52,8 @@ public final class CollectAggFunction<T>
// --------------------------------------------------------------------------------------------
@Override
public DataType[] getInputDataTypes() {
return new DataType[] {elementDataType};
public List<DataType> getArgumentDataTypes() {
return Collections.singletonList(elementDataType);
}
@Override
......
......@@ -24,14 +24,18 @@ import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.data.binary.BinaryStringData;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import java.util.Collections;
import java.util.List;
import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType;
/** Built-in FIRST_VALUE aggregate function. */
@Internal
public final class FirstValueAggFunction<T> extends InternalAggregateFunction<T, RowData> {
public final class FirstValueAggFunction<T> extends BuiltInAggregateFunction<T, RowData> {
private transient DataType valueDataType;
......@@ -44,8 +48,8 @@ public final class FirstValueAggFunction<T> extends InternalAggregateFunction<T,
// --------------------------------------------------------------------------------------------
@Override
public DataType[] getInputDataTypes() {
return new DataType[] {valueDataType};
public List<DataType> getArgumentDataTypes() {
return Collections.singletonList(valueDataType);
}
@Override
......
......@@ -23,10 +23,12 @@ import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.data.binary.BinaryStringData;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
......@@ -36,7 +38,7 @@ import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataTyp
/** Built-in FIRST_VALUE with retraction aggregate function. */
@Internal
public final class FirstValueWithRetractAggFunction<T>
extends InternalAggregateFunction<
extends BuiltInAggregateFunction<
T, FirstValueWithRetractAggFunction.FirstValueWithRetractAccumulator<T>> {
private transient DataType valueDataType;
......@@ -50,8 +52,8 @@ public final class FirstValueWithRetractAggFunction<T>
// --------------------------------------------------------------------------------------------
@Override
public DataType[] getInputDataTypes() {
return new DataType[] {valueDataType};
public List<DataType> getArgumentDataTypes() {
return Collections.singletonList(valueDataType);
}
@Override
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.functions.aggfunctions;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.planner.plan.utils.AggFunctionFactory;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.TypeInference;
import static org.apache.flink.table.types.inference.TypeStrategies.explicit;
/**
* Base class for fully resolved and strongly typed {@link AggregateFunction}s provided by {@link
* AggFunctionFactory}.
*
* <p>We override {@link #getTypeInference(DataTypeFactory)} in case the internal function is used
* externally for testing.
*/
@Internal
public abstract class InternalAggregateFunction<T, ACC> extends AggregateFunction<T, ACC> {
public abstract DataType[] getInputDataTypes();
public abstract DataType getAccumulatorDataType();
public abstract DataType getOutputDataType();
@Override
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
return TypeInference.newBuilder()
.typedArguments(getInputDataTypes())
.accumulatorTypeStrategy(explicit(getAccumulatorDataType()))
.outputTypeStrategy(explicit(getOutputDataType()))
.build();
}
}
......@@ -24,14 +24,18 @@ import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.data.binary.BinaryStringData;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import java.util.Collections;
import java.util.List;
import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType;
/** Built-in LAST_VALUE aggregate function. */
@Internal
public final class LastValueAggFunction<T> extends InternalAggregateFunction<T, RowData> {
public final class LastValueAggFunction<T> extends BuiltInAggregateFunction<T, RowData> {
private transient DataType valueDataType;
......@@ -44,8 +48,8 @@ public final class LastValueAggFunction<T> extends InternalAggregateFunction<T,
// --------------------------------------------------------------------------------------------
@Override
public DataType[] getInputDataTypes() {
return new DataType[] {valueDataType};
public List<DataType> getArgumentDataTypes() {
return Collections.singletonList(valueDataType);
}
@Override
......
......@@ -23,10 +23,12 @@ import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.data.binary.BinaryStringData;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
......@@ -36,7 +38,7 @@ import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataTyp
/** Built-in LAST_VALUE with retraction aggregate function. */
@Internal
public final class LastValueWithRetractAggFunction<T>
extends InternalAggregateFunction<
extends BuiltInAggregateFunction<
T, LastValueWithRetractAggFunction.LastValueWithRetractAccumulator<T>> {
private transient DataType valueDataType;
......@@ -50,8 +52,8 @@ public final class LastValueWithRetractAggFunction<T>
// --------------------------------------------------------------------------------------------
@Override
public DataType[] getInputDataTypes() {
return new DataType[] {valueDataType};
public List<DataType> getArgumentDataTypes() {
return Collections.singletonList(valueDataType);
}
@Override
......
......@@ -24,17 +24,19 @@ import org.apache.flink.table.api.dataview.ListView;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.data.binary.BinaryStringData;
import org.apache.flink.table.data.binary.BinaryStringDataUtil;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.util.FlinkRuntimeException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/** Built-in LISTAGG with retraction aggregate function. */
@Internal
public final class ListAggWithRetractAggFunction
extends InternalAggregateFunction<
extends BuiltInAggregateFunction<
StringData, ListAggWithRetractAggFunction.ListAggWithRetractAccumulator> {
private static final long serialVersionUID = -2836795091288790955L;
......@@ -46,8 +48,8 @@ public final class ListAggWithRetractAggFunction
// --------------------------------------------------------------------------------------------
@Override
public DataType[] getInputDataTypes() {
return new DataType[] {DataTypes.STRING().bridgedTo(StringData.class)};
public List<DataType> getArgumentDataTypes() {
return Collections.singletonList(DataTypes.STRING().bridgedTo(StringData.class));
}
@Override
......
......@@ -24,17 +24,19 @@ import org.apache.flink.table.api.dataview.ListView;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.data.binary.BinaryStringData;
import org.apache.flink.table.data.binary.BinaryStringDataUtil;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.util.FlinkRuntimeException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
/** Built-in LISTAGGWS with retraction aggregate function. */
@Internal
public final class ListAggWsWithRetractAggFunction
extends InternalAggregateFunction<
extends BuiltInAggregateFunction<
StringData, ListAggWsWithRetractAggFunction.ListAggWsWithRetractAccumulator> {
private static final long serialVersionUID = -8627988150350160473L;
......@@ -44,11 +46,10 @@ public final class ListAggWsWithRetractAggFunction
// --------------------------------------------------------------------------------------------
@Override
public DataType[] getInputDataTypes() {
return new DataType[] {
DataTypes.STRING().bridgedTo(StringData.class),
DataTypes.STRING().bridgedTo(StringData.class)
};
public List<DataType> getArgumentDataTypes() {
return Arrays.asList(
DataTypes.STRING().bridgedTo(StringData.class),
DataTypes.STRING().bridgedTo(StringData.class));
}
@Override
......
......@@ -21,9 +21,12 @@ package org.apache.flink.table.planner.functions.aggfunctions;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
......@@ -32,7 +35,7 @@ import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataTyp
/** Built-in MAX with retraction aggregate function. */
@Internal
public final class MaxWithRetractAggFunction<T extends Comparable<T>>
extends InternalAggregateFunction<
extends BuiltInAggregateFunction<
T, MaxWithRetractAggFunction.MaxWithRetractAccumulator<T>> {
private static final long serialVersionUID = -5860934997657147836L;
......@@ -48,8 +51,8 @@ public final class MaxWithRetractAggFunction<T extends Comparable<T>>
// --------------------------------------------------------------------------------------------
@Override
public DataType[] getInputDataTypes() {
return new DataType[] {valueDataType};
public List<DataType> getArgumentDataTypes() {
return Collections.singletonList(valueDataType);
}
@Override
......
......@@ -21,9 +21,12 @@ package org.apache.flink.table.planner.functions.aggfunctions;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
......@@ -32,7 +35,7 @@ import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataTyp
/** Built-in MIN with retraction aggregate function. */
@Internal
public final class MinWithRetractAggFunction<T extends Comparable<T>>
extends InternalAggregateFunction<
extends BuiltInAggregateFunction<
T, MinWithRetractAggFunction.MinWithRetractAccumulator<T>> {
private static final long serialVersionUID = 4253774292802374843L;
......@@ -48,8 +51,8 @@ public final class MinWithRetractAggFunction<T extends Comparable<T>>
// --------------------------------------------------------------------------------------------
@Override
public DataType[] getInputDataTypes() {
return new DataType[] {valueDataType};
public List<DataType> getArgumentDataTypes() {
return Collections.singletonList(valueDataType);
}
@Override
......
......@@ -28,6 +28,7 @@ import org.apache.flink.table.planner.functions.aggfunctions._
import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction
import org.apache.flink.table.planner.functions.sql.{SqlFirstLastValueAggFunction, SqlListAggFunction}
import org.apache.flink.table.planner.functions.utils.AggSqlFunction
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction
import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.table.types.logical._
......@@ -42,7 +43,7 @@ import scala.collection.JavaConversions._
/**
* Factory for creating runtime implementation for internal aggregate functions that are declared
* as subclasses of [[SqlAggFunction]] in Calcite but not as [[BridgingSqlAggFunction]]. The factory
* returns [[DeclarativeAggregateFunction]] or [[InternalAggregateFunction]].
* returns [[DeclarativeAggregateFunction]] or [[BuiltInAggregateFunction]].
*
* @param inputType the input rel data type
* @param orderKeyIdx the indexes of order key (null when is not over agg)
......
......@@ -22,13 +22,13 @@ import org.apache.flink.table.api.{TableConfig, TableException}
import org.apache.flink.table.data.RowData
import org.apache.flink.table.expressions.ExpressionUtils.extractValue
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.{AggregateFunction, FunctionKind, ImperativeAggregateFunction, UserDefinedFunction, UserDefinedFunctionHelper}
import org.apache.flink.table.functions._
import org.apache.flink.table.planner.JLong
import org.apache.flink.table.planner.calcite.FlinkRelBuilder.PlannerNamedWindowProperty
import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, FlinkTypeSystem}
import org.apache.flink.table.planner.delegation.PlannerBase
import org.apache.flink.table.planner.expressions.{PlannerProctimeAttribute, PlannerRowtimeAttribute, PlannerWindowEnd, PlannerWindowStart}
import org.apache.flink.table.planner.functions.aggfunctions.{DeclarativeAggregateFunction, InternalAggregateFunction}
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction
import org.apache.flink.table.planner.functions.inference.OperatorBindingCallContext
import org.apache.flink.table.planner.functions.sql.{FlinkSqlOperatorTable, SqlFirstLastValueAggFunction, SqlListAggFunction}
......@@ -41,6 +41,7 @@ import org.apache.flink.table.planner.typeutils.DataViewUtils
import org.apache.flink.table.planner.typeutils.DataViewUtils.DataViewSpec
import org.apache.flink.table.planner.typeutils.LegacyDataViewUtils.useNullSerializerForStateViewFieldsFromAccType
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction
import org.apache.flink.table.runtime.operators.bundle.trigger.CountBundleTrigger
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
import org.apache.flink.table.types.DataType
......@@ -483,13 +484,13 @@ object AggregateUtil extends Enumeration {
hasStateBackedDataViews: Boolean)
: AggregateInfo = udf match {
case imperativeFunction: InternalAggregateFunction[_, _] =>
case imperativeFunction: BuiltInAggregateFunction[_, _] =>
createImperativeAggregateInfo(
call,
imperativeFunction,
index,
argIndexes,
imperativeFunction.getInputDataTypes,
imperativeFunction.getArgumentDataTypes.asScala.toArray,
imperativeFunction.getAccumulatorDataType,
imperativeFunction.getOutputDataType,
needsRetraction,
......
......@@ -22,9 +22,10 @@ import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rex.{RexCall, RexFieldAccess, RexNode}
import org.apache.flink.table.functions.FunctionDefinition
import org.apache.flink.table.functions.python.{PythonFunction, PythonFunctionKind}
import org.apache.flink.table.planner.functions.aggfunctions.{DeclarativeAggregateFunction, InternalAggregateFunction}
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
import org.apache.flink.table.planner.functions.bridging.{BridgingSqlAggFunction, BridgingSqlFunction}
import org.apache.flink.table.planner.functions.utils.{AggSqlFunction, ScalarSqlFunction, TableSqlFunction}
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction
import scala.collection.JavaConversions._
......@@ -95,7 +96,7 @@ object PythonUtil {
val aggregation = call.getAggregation
aggregation match {
case function: AggSqlFunction =>
function.aggregateFunction.isInstanceOf[InternalAggregateFunction[_, _]]
function.aggregateFunction.isInstanceOf[BuiltInAggregateFunction[_, _]]
case function: BridgingSqlAggFunction =>
function.getDefinition.isInstanceOf[DeclarativeAggregateFunction]
case _ => true
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册