提交 91c90c5d 编写于 作者: S shaoxuan-wang 提交者: Fabian Hueske

[FLINK-5915] [table] Forward the complete aggregate ArgList to aggregate runtime functions.

This closes #3647.
上级 48890285
......@@ -33,7 +33,7 @@ import org.apache.flink.types.Row
*/
class AggregateAggFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int])
private val aggFields: Array[Array[Int]])
extends DataStreamAggFunc[Row, Row, Row] {
override def createAccumulator(): Row = {
......@@ -51,7 +51,7 @@ class AggregateAggFunction(
var i = 0
while (i < aggregates.length) {
val acc = accumulatorRow.getField(i).asInstanceOf[Accumulator]
val v = value.getField(aggFields(i))
val v = value.getField(aggFields(i)(0))
aggregates(i).accumulate(acc, v)
i += 1
}
......
......@@ -45,6 +45,7 @@ import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeIn
import org.apache.flink.types.Row
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
object AggregateUtil {
......@@ -886,10 +887,10 @@ object AggregateUtil {
aggregateCalls: Seq[AggregateCall],
inputType: RelDataType,
needRetraction: Boolean)
: (Array[Int], Array[TableAggregateFunction[_ <: Any]]) = {
: (Array[Array[Int]], Array[TableAggregateFunction[_ <: Any]]) = {
// store the aggregate fields of each aggregate function, by the same order of aggregates.
val aggFieldIndexes = new Array[Int](aggregateCalls.size)
val aggFieldIndexes = new Array[Array[Int]](aggregateCalls.size)
val aggregates = new Array[TableAggregateFunction[_ <: Any]](aggregateCalls.size)
// create aggregate function instances by function type and aggregate field data type.
......@@ -897,7 +898,7 @@ object AggregateUtil {
val argList: util.List[Integer] = aggregateCall.getArgList
if (argList.isEmpty) {
if (aggregateCall.getAggregation.isInstanceOf[SqlCountAggFunction]) {
aggFieldIndexes(index) = 0
aggFieldIndexes(index) = Array[Int](0)
} else {
throw new TableException("Aggregate fields should not be empty.")
}
......@@ -905,9 +906,10 @@ object AggregateUtil {
if (argList.size() > 1) {
throw new TableException("Currently, do not support aggregate on multi fields.")
}
aggFieldIndexes(index) = argList.get(0)
aggFieldIndexes(index) = argList.asScala.map(i => i.intValue).toArray
}
val sqlTypeName = inputType.getFieldList.get(aggFieldIndexes(index)).getType.getSqlTypeName
val sqlTypeName = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType
.getSqlTypeName
aggregateCall.getAggregation match {
case _: SqlSumAggFunction | _: SqlSumEmptyIsZeroAggFunction =>
......
......@@ -37,7 +37,7 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo
class BoundedProcessingOverRowProcessFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val aggFields: Array[Array[Int]],
private val precedingOffset: Long,
private val forwardedFieldCount: Int,
private val aggregatesTypeInfo: RowTypeInfo,
......@@ -118,7 +118,7 @@ class BoundedProcessingOverRowProcessFunction(
i = 0
while (i < aggregates.length) {
val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
aggregates(i).retract(accumulator, retractList.get(0).getField(aggFields(i)))
aggregates(i).retract(accumulator, retractList.get(0).getField(aggFields(i)(0)))
i += 1
}
retractList.remove(0)
......@@ -157,7 +157,7 @@ class BoundedProcessingOverRowProcessFunction(
while (i < aggregates.length) {
val index = forwardedFieldCount + i
val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
aggregates(i).accumulate(accumulator, input.getField(aggFields(i)))
aggregates(i).accumulate(accumulator, input.getField(aggFields(i)(0)))
output.setField(index, aggregates(i).getValue(accumulator))
i += 1
}
......
......@@ -38,7 +38,7 @@ import org.apache.flink.util.{Collector, Preconditions}
*/
class DataSetAggFunction(
private val aggregates: Array[AggregateFunction[_ <: Any]],
private val aggInFields: Array[Int],
private val aggInFields: Array[Array[Int]],
private val aggOutMapping: Array[(Int, Int)],
private val gkeyOutMapping: Array[(Int, Int)],
private val groupingSetsMapping: Array[(Int, Int)],
......@@ -82,7 +82,7 @@ class DataSetAggFunction(
// accumulate
i = 0
while (i < aggregates.length) {
aggregates(i).accumulate(accumulators(i), record.getField(aggInFields(i)))
aggregates(i).accumulate(accumulators(i), record.getField(aggInFields(i)(0)))
i += 1
}
......
......@@ -35,7 +35,7 @@ import org.apache.flink.util.{Collector, Preconditions}
*/
class DataSetPreAggFunction(
private val aggregates: Array[AggregateFunction[_ <: Any]],
private val aggInFields: Array[Int],
private val aggInFields: Array[Array[Int]],
private val groupingKeys: Array[Int])
extends AbstractRichFunction
with GroupCombineFunction[Row, Row]
......@@ -78,7 +78,7 @@ class DataSetPreAggFunction(
// accumulate
i = 0
while (i < aggregates.length) {
aggregates(i).accumulate(accumulators(i), record.getField(aggInFields(i)))
aggregates(i).accumulate(accumulators(i), record.getField(aggInFields(i)(0)))
i += 1
}
// check if this record is the last record
......
......@@ -36,7 +36,7 @@ import org.apache.flink.util.Preconditions
*/
class DataSetWindowAggMapFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val aggFields: Array[Array[Int]],
private val groupingKeys: Array[Int],
private val timeFieldPos: Int, // time field position in input row
private val tumbleTimeWindowSize: Option[Long],
......@@ -62,7 +62,7 @@ class DataSetWindowAggMapFunction(
var i = 0
while (i < aggregates.length) {
val agg = aggregates(i)
val fieldValue = input.getField(aggFields(i))
val fieldValue = input.getField(aggFields(i)(0))
val accumulator = agg.createAccumulator()
agg.accumulate(accumulator, fieldValue)
output.setField(groupingKeys.length + i, accumulator)
......
......@@ -40,7 +40,7 @@ import org.apache.flink.util.{Collector, Preconditions}
*/
class RangeClauseBoundedOverProcessFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val aggFields: Array[Array[Int]],
private val forwardedFieldCount: Int,
private val aggregationStateType: RowTypeInfo,
private val inputRowType: RowTypeInfo,
......@@ -160,7 +160,7 @@ class RangeClauseBoundedOverProcessFunction(
val accumulator = accumulators.getField(aggregatesIndex).asInstanceOf[Accumulator]
aggregates(aggregatesIndex)
.retract(accumulator, retractDataList.get(dataListIndex)
.getField(aggFields(aggregatesIndex)))
.getField(aggFields(aggregatesIndex)(0)))
aggregatesIndex += 1
}
dataListIndex += 1
......@@ -177,7 +177,7 @@ class RangeClauseBoundedOverProcessFunction(
while (aggregatesIndex < aggregates.length) {
val accumulator = accumulators.getField(aggregatesIndex).asInstanceOf[Accumulator]
aggregates(aggregatesIndex).accumulate(accumulator, inputs.get(dataListIndex)
.getField(aggFields(aggregatesIndex)))
.getField(aggFields(aggregatesIndex)(0)))
aggregatesIndex += 1
}
dataListIndex += 1
......
......@@ -41,7 +41,7 @@ import org.apache.flink.util.{Collector, Preconditions}
*/
class RowsClauseBoundedOverProcessFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val aggFields: Array[Array[Int]],
private val forwardedFieldCount: Int,
private val aggregationStateType: RowTypeInfo,
private val inputRowType: RowTypeInfo,
......@@ -202,7 +202,7 @@ class RowsClauseBoundedOverProcessFunction(
i = 0
while (i < aggregates.length) {
val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
aggregates(i).retract(accumulator, retractRow.getField(aggFields(i)))
aggregates(i).retract(accumulator, retractRow.getField(aggFields(i)(0)))
i += 1
}
}
......@@ -212,7 +212,7 @@ class RowsClauseBoundedOverProcessFunction(
while (i < aggregates.length) {
val index = forwardedFieldCount + i
val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
aggregates(i).accumulate(accumulator, input.getField(aggFields(i)))
aggregates(i).accumulate(accumulator, input.getField(aggFields(i)(0)))
output.setField(index, aggregates(i).getValue(accumulator))
i += 1
}
......
......@@ -43,7 +43,7 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
*/
abstract class UnboundedEventTimeOverProcessFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val aggFields: Array[Array[Int]],
private val forwardedFieldCount: Int,
private val intermediateType: TypeInformation[Row],
private val inputType: TypeInformation[Row])
......@@ -217,7 +217,7 @@ abstract class UnboundedEventTimeOverProcessFunction(
*/
class UnboundedEventTimeRowsOverProcessFunction(
aggregates: Array[AggregateFunction[_]],
aggFields: Array[Int],
aggFields: Array[Array[Int]],
forwardedFieldCount: Int,
intermediateType: TypeInformation[Row],
inputType: TypeInformation[Row])
......@@ -250,7 +250,7 @@ class UnboundedEventTimeRowsOverProcessFunction(
while (i < aggregates.length) {
val index = forwardedFieldCount + i
val accumulator = lastAccumulator.getField(i).asInstanceOf[Accumulator]
aggregates(i).accumulate(accumulator, curRow.getField(aggFields(i)))
aggregates(i).accumulate(accumulator, curRow.getField(aggFields(i)(0)))
output.setField(index, aggregates(i).getValue(accumulator))
i += 1
}
......@@ -269,7 +269,7 @@ class UnboundedEventTimeRowsOverProcessFunction(
*/
class UnboundedEventTimeRangeOverProcessFunction(
aggregates: Array[AggregateFunction[_]],
aggFields: Array[Int],
aggFields: Array[Array[Int]],
forwardedFieldCount: Int,
intermediateType: TypeInformation[Row],
inputType: TypeInformation[Row])
......@@ -294,7 +294,7 @@ class UnboundedEventTimeRangeOverProcessFunction(
while (i < aggregates.length) {
val index = forwardedFieldCount + i
val accumulator = lastAccumulator.getField(i).asInstanceOf[Accumulator]
aggregates(i).accumulate(accumulator, curRow.getField(aggFields(i)))
aggregates(i).accumulate(accumulator, curRow.getField(aggFields(i)(0)))
i += 1
}
j += 1
......
......@@ -37,7 +37,7 @@ import org.apache.flink.util.{Collector, Preconditions}
*/
class UnboundedNonPartitionedProcessingOverProcessFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val aggFields: Array[Array[Int]],
private val forwardedFieldCount: Int,
private val aggregationStateType: RowTypeInfo)
extends ProcessFunction[Row, Row] with CheckpointedFunction{
......@@ -82,7 +82,7 @@ class UnboundedNonPartitionedProcessingOverProcessFunction(
while (i < aggregates.length) {
val index = forwardedFieldCount + i
val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
aggregates(i).accumulate(accumulator, input.getField(aggFields(i)))
aggregates(i).accumulate(accumulator, input.getField(aggFields(i)(0)))
output.setField(index, aggregates(i).getValue(accumulator))
i += 1
}
......
......@@ -28,7 +28,7 @@ import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
class UnboundedProcessingOverProcessFunction(
private val aggregates: Array[AggregateFunction[_]],
private val aggFields: Array[Int],
private val aggFields: Array[Array[Int]],
private val forwardedFieldCount: Int,
private val aggregationStateType: RowTypeInfo)
extends ProcessFunction[Row, Row]{
......@@ -75,7 +75,7 @@ class UnboundedProcessingOverProcessFunction(
while (i < aggregates.length) {
val index = forwardedFieldCount + i
val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
aggregates(i).accumulate(accumulator, input.getField(aggFields(i)))
aggregates(i).accumulate(accumulator, input.getField(aggFields(i)(0)))
output.setField(index, aggregates(i).getValue(accumulator))
i += 1
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册