提交 5784f395 编写于 作者: F Fabian Hueske

[FLINK-3936] [tableAPI] Add MIN/MAX aggregation for Boolean.

This closes #2035
上级 ec38a212
......@@ -192,6 +192,8 @@ object AggregateUtil {
new FloatMinAggregate
case DOUBLE =>
new DoubleMinAggregate
case BOOLEAN =>
new BooleanMinAggregate
case sqlType: SqlTypeName =>
throw new TableException("Min aggregate does no support type:" + sqlType)
}
......@@ -209,6 +211,8 @@ object AggregateUtil {
new FloatMaxAggregate
case DOUBLE =>
new DoubleMaxAggregate
case BOOLEAN =>
new BooleanMaxAggregate
case sqlType: SqlTypeName =>
throw new TableException("Max aggregate does no support type:" + sqlType)
}
......
......@@ -20,9 +20,8 @@ package org.apache.flink.api.table.runtime.aggregate
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row
abstract class MaxAggregate[T: Numeric] extends Aggregate[T] {
abstract class MaxAggregate[T](implicit ord: Ordering[T]) extends Aggregate[T] {
private val numeric = implicitly[Numeric[T]]
protected var maxIndex = -1
/**
......@@ -49,7 +48,8 @@ abstract class MaxAggregate[T: Numeric] extends Aggregate[T] {
override def merge(intermediate: Row, buffer: Row): Unit = {
val partialValue = intermediate.productElement(maxIndex).asInstanceOf[T]
val bufferValue = buffer.productElement(maxIndex).asInstanceOf[T]
buffer.setField(maxIndex, numeric.max(partialValue, bufferValue))
val max: T = if (ord.compare(partialValue, bufferValue) > 0) partialValue else bufferValue
buffer.setField(maxIndex, max)
}
/**
......@@ -122,3 +122,12 @@ class DoubleMaxAggregate extends MaxAggregate[Double] {
intermediate.setField(maxIndex, Double.MinValue)
}
}
class BooleanMaxAggregate extends MaxAggregate[Boolean] {
override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)
override def initiate(intermediate: Row): Unit = {
intermediate.setField(maxIndex, false)
}
}
......@@ -20,9 +20,8 @@ package org.apache.flink.api.table.runtime.aggregate
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row
abstract class MinAggregate[T: Numeric] extends Aggregate[T]{
abstract class MinAggregate[T](implicit ord: Ordering[T]) extends Aggregate[T]{
private val numeric = implicitly[Numeric[T]]
protected var minIndex: Int = _
/**
......@@ -49,7 +48,8 @@ abstract class MinAggregate[T: Numeric] extends Aggregate[T]{
override def merge(partial: Row, buffer: Row): Unit = {
val partialValue = partial.productElement(minIndex).asInstanceOf[T]
val bufferValue = buffer.productElement(minIndex).asInstanceOf[T]
buffer.setField(minIndex, numeric.min(partialValue, bufferValue))
val min: T = if (ord.compare(partialValue, bufferValue) < 0) partialValue else bufferValue
buffer.setField(minIndex, min)
}
/**
......@@ -122,3 +122,12 @@ class DoubleMinAggregate extends MinAggregate[Double] {
intermediate.setField(minIndex, Double.MaxValue)
}
}
class BooleanMinAggregate extends MinAggregate[Boolean] {
override def intermediateDataType = Array(BasicTypeInfo.BOOLEAN_TYPE_INFO)
override def initiate(intermediate: Row): Unit = {
intermediate.setField(minIndex, true)
}
}
......@@ -91,3 +91,32 @@ class DoubleMaxAggregateTest extends MaxAggregateTestBase[Double] {
override def aggregator: Aggregate[Double] = new DoubleMaxAggregate()
}
class BooleanMaxAggregateTest extends AggregateTestBase[Boolean] {
override def inputValueSets: Seq[Seq[Boolean]] = Seq(
Seq(
false,
false,
false
),
Seq(
true,
true,
true
),
Seq(
true,
false,
null.asInstanceOf[Boolean],
true,
false,
true,
null.asInstanceOf[Boolean]
)
)
override def expectedResults: Seq[Boolean] = Seq(false, true, true)
override def aggregator: Aggregate[Boolean] = new BooleanMaxAggregate()
}
......@@ -91,3 +91,32 @@ class DoubleMinAggregateTest extends MinAggregateTestBase[Double] {
override def aggregator: Aggregate[Double] = new DoubleMinAggregate()
}
class BooleanMinAggregateTest extends AggregateTestBase[Boolean] {
override def inputValueSets: Seq[Seq[Boolean]] = Seq(
Seq(
false,
false,
false
),
Seq(
true,
true,
true
),
Seq(
true,
false,
null.asInstanceOf[Boolean],
true,
false,
true,
null.asInstanceOf[Boolean]
)
)
override def expectedResults: Seq[Boolean] = Seq(false, true, false)
override def aggregator: Aggregate[Boolean] = new BooleanMinAggregate()
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册