Commit 495f104b written by twalthr, committed by Fabian Hueske

[FLINK-5884] [table] Integrate time indicators for Table API & SQL

This closes #3808.
Parent 28ab7375
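For illustration (not part of the commit message): a minimal sketch of the API surface this change introduces; the stream and field names are hypothetical, and the Scala DataStream and Table APIs are assumed.

    import org.apache.flink.streaming.api.scala._
    import org.apache.flink.table.api.TableEnvironment
    import org.apache.flink.table.api.scala._

    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // hypothetical stream of (timestamp, name) records
    val stream: DataStream[(Long, String)] = env.fromElements((1L, "a"), (2L, "b"))

    // 'ts.rowtime declares an event-time attribute, 'p.proctime appends a processing-time attribute
    val table = tEnv.fromDataStream(stream, 'ts.rowtime, 'name, 'p.proctime)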
......@@ -194,7 +194,30 @@ abstract class BatchTableEnvironment(
protected def registerDataSetInternal[T](
name: String, dataSet: DataSet[T], fields: Array[Expression]): Unit = {
val (fieldNames, fieldIndexes) = getFieldInfo[T](dataSet.getType, fields)
val (fieldNames, fieldIndexes) = getFieldInfo[T](
dataSet.getType,
fields,
ignoreTimeAttributes = true)
// validate and extract time attributes
val (rowtime, proctime) = validateAndExtractTimeAttributes(fieldNames, fieldIndexes, fields)
// don't allow proctime on batch
proctime match {
case Some(_) =>
throw new ValidationException(
"A proctime attribute is not allowed in a batch environment. " +
"Working with processing-time on batch would lead to non-deterministic results.")
case _ => // ok
}
// rowtime must not extend the schema of a batch table
rowtime match {
case Some((idx, _)) if idx >= dataSet.getType.getArity =>
throw new ValidationException(
"A rowtime attribute must be defined on an existing field in a batch environment.")
case _ => // ok
}
val dataSetTable = new DataSetTable[T](dataSet, fieldIndexes, fieldNames)
registerTableInternal(name, dataSetTable)
}
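// Illustration only (not part of the diff): a sketch of what these checks permit for batch
// tables; the environment, data set, and field names below are hypothetical assumptions.
import org.apache.flink.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._

val benv = ExecutionEnvironment.getExecutionEnvironment
val btEnv = TableEnvironment.getTableEnvironment(benv)
val dataSet: DataSet[(Long, String)] = benv.fromElements((1L, "a"))

// accepted: the rowtime attribute is declared on an existing field of the DataSet
btEnv.registerDataSet("OrdersA", dataSet, 'ts.rowtime, 'name)
// rejected with the ValidationException above: proctime is not allowed on batch
// btEnv.registerDataSet("OrdersB", dataSet, 'ts, 'name, 'p.proctime)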
......
......@@ -30,6 +30,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.GenericTypeInfo
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.flink.table.calcite.RelTimeIndicatorConverter
import org.apache.flink.table.explain.PlanJsonParser
import org.apache.flink.table.expressions.Expression
import org.apache.flink.table.plan.nodes.FlinkConventions
......@@ -86,47 +87,6 @@ abstract class StreamTableEnvironment(
/** Returns a unique table name according to the internal naming pattern. */
protected def createUniqueTableName(): String = "_DataStreamTable_" + nameCntr.getAndIncrement()
/**
* Returns field names and field positions for a given [[TypeInformation]].
*
* Field names are automatically extracted for
* [[org.apache.flink.api.common.typeutils.CompositeType]].
* The method fails if inputType is not a
* [[org.apache.flink.api.common.typeutils.CompositeType]].
*
* @param inputType The TypeInformation to extract the field names and positions from.
* @tparam A The type of the TypeInformation.
* @return A tuple of two arrays holding the field names and corresponding field positions.
*/
override protected[flink] def getFieldInfo[A](inputType: TypeInformation[A])
: (Array[String], Array[Int]) = {
val fieldInfo = super.getFieldInfo(inputType)
if (fieldInfo._1.contains("rowtime")) {
throw new TableException("'rowtime' ia a reserved field name in stream environment.")
}
fieldInfo
}
/**
* Returns field names and field positions for a given [[TypeInformation]] and [[Array]] of
* [[Expression]].
*
* @param inputType The [[TypeInformation]] against which the [[Expression]]s are evaluated.
* @param exprs The expressions that define the field names.
* @tparam A The type of the TypeInformation.
* @return A tuple of two arrays holding the field names and corresponding field positions.
*/
override protected[flink] def getFieldInfo[A](
inputType: TypeInformation[A],
exprs: Array[Expression])
: (Array[String], Array[Int]) = {
val fieldInfo = super.getFieldInfo(inputType, exprs)
if (fieldInfo._1.contains("rowtime")) {
throw new TableException("'rowtime' is a reserved field name in stream environment.")
}
fieldInfo
}
/**
* Registers an external [[StreamTableSource]] in this [[TableEnvironment]]'s catalog.
* Registered tables can be referenced in SQL queries.
......@@ -145,6 +105,7 @@ abstract class StreamTableEnvironment(
"StreamTableEnvironment")
}
}
/**
* Writes a [[Table]] to a [[TableSink]].
*
......@@ -185,7 +146,9 @@ abstract class StreamTableEnvironment(
val dataStreamTable = new DataStreamTable[T](
dataStream,
fieldIndexes,
fieldNames
fieldNames,
None,
None
)
registerTableInternal(name, dataStreamTable)
}
......@@ -200,15 +163,26 @@ abstract class StreamTableEnvironment(
* @tparam T The type of the [[DataStream]].
*/
protected def registerDataStreamInternal[T](
name: String,
dataStream: DataStream[T],
fields: Array[Expression]): Unit = {
name: String,
dataStream: DataStream[T],
fields: Array[Expression])
: Unit = {
val (fieldNames, fieldIndexes) = getFieldInfo[T](
dataStream.getType,
fields,
ignoreTimeAttributes = false)
// validate and extract time attributes
val (rowtime, proctime) = validateAndExtractTimeAttributes(fieldNames, fieldIndexes, fields)
val (fieldNames, fieldIndexes) = getFieldInfo[T](dataStream.getType, fields)
val dataStreamTable = new DataStreamTable[T](
dataStream,
fieldIndexes,
fieldNames
fieldNames,
rowtime,
proctime
)
registerTableInternal(name, dataStreamTable)
}
......@@ -259,7 +233,10 @@ abstract class StreamTableEnvironment(
// 1. decorrelate
val decorPlan = RelDecorrelator.decorrelateQuery(relNode)
// 2. normalize the logical plan
// 2. convert time indicators
val convPlan = RelTimeIndicatorConverter.convert(decorPlan, getRelBuilder.getRexBuilder)
// 3. normalize the logical plan
val normRuleSet = getNormRuleSet
val normalizedPlan = if (normRuleSet.iterator().hasNext) {
runHepPlanner(HepMatchOrder.BOTTOM_UP, normRuleSet, decorPlan, decorPlan.getTraitSet)
......@@ -267,7 +244,7 @@ abstract class StreamTableEnvironment(
decorPlan
}
// 3. optimize the logical Flink plan
// 4. optimize the logical Flink plan
val logicalOptRuleSet = getLogicalOptRuleSet
val logicalOutputProps = relNode.getTraitSet.replace(FlinkConventions.LOGICAL).simplify()
val logicalPlan = if (logicalOptRuleSet.iterator().hasNext) {
......@@ -276,7 +253,7 @@ abstract class StreamTableEnvironment(
normalizedPlan
}
// 4. optimize the physical Flink plan
// 5. optimize the physical Flink plan
val physicalOptRuleSet = getPhysicalOptRuleSet
val physicalOutputProps = relNode.getTraitSet.replace(FlinkConventions.DATASTREAM).simplify()
val physicalPlan = if (physicalOptRuleSet.iterator().hasNext) {
......@@ -285,7 +262,7 @@ abstract class StreamTableEnvironment(
logicalPlan
}
// 5. decorate the optimized plan
// 6. decorate the optimized plan
val decoRuleSet = getDecoRuleSet
val decoratedPlan = if (decoRuleSet.iterator().hasNext) {
runHepPlanner(HepMatchOrder.BOTTOM_UP, decoRuleSet, physicalPlan, physicalPlan.getTraitSet)
......
......@@ -52,6 +52,9 @@ import org.apache.flink.table.codegen.{CodeGenerator, ExpressionReducer}
import org.apache.flink.table.expressions.{Alias, Expression, UnresolvedFieldReference}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.functions.{ScalarFunction, TableFunction, AggregateFunction}
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{checkForInstantiation, checkNotSingleton, createScalarSqlFunction, createTableSqlFunctions}
import org.apache.flink.table.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.plan.cost.DataSetCostFactory
import org.apache.flink.table.plan.logical.{CatalogNode, LogicalRelNode}
import org.apache.flink.table.plan.rules.FlinkRuleSets
......@@ -598,70 +601,94 @@ abstract class TableEnvironment(val config: TableConfig) {
/**
* Returns field names and field positions for a given [[TypeInformation]] and [[Array]] of
* [[Expression]].
* [[Expression]]. It does not handle time attributes itself; if the ignore flag is set, time
* attributes are treated as regular field expressions, otherwise they are skipped.
*
* @param inputType The [[TypeInformation]] against which the [[Expression]]s are evaluated.
* @param exprs The expressions that define the field names.
* @param ignoreTimeAttributes if true, time attributes are handled as regular field expressions.
* @tparam A The type of the TypeInformation.
* @return A tuple of two arrays holding the field names and corresponding field positions.
*/
protected[flink] def getFieldInfo[A](
inputType: TypeInformation[A],
exprs: Array[Expression]): (Array[String], Array[Int]) = {
inputType: TypeInformation[A],
exprs: Array[Expression],
ignoreTimeAttributes: Boolean)
: (Array[String], Array[Int]) = {
TableEnvironment.validateType(inputType)
val filteredExprs = if (ignoreTimeAttributes) {
exprs.map {
case ta: TimeAttribute => ta.expression
case e@_ => e
}
} else {
exprs
}
val indexedNames: Array[(Int, String)] = inputType match {
case g: GenericTypeInfo[A] if g.getTypeClass == classOf[Row] =>
throw new TableException(
"An input of GenericTypeInfo<Row> cannot be converted to Table. " +
"Please specify the type of the input with a RowTypeInfo.")
case a: AtomicType[A] =>
if (exprs.length != 1) {
throw new TableException("Table of atomic type can only have a single field.")
}
exprs.map {
case UnresolvedFieldReference(name) => (0, name)
filteredExprs.zipWithIndex flatMap {
case (UnresolvedFieldReference(name), idx) =>
if (idx > 0) {
throw new TableException("Table of atomic type can only have a single field.")
}
Some((0, name))
case (_: TimeAttribute, _) if ignoreTimeAttributes =>
None
case _ => throw new TableException("Field reference expression requested.")
}
case t: TupleTypeInfo[A] =>
exprs.zipWithIndex.map {
case (UnresolvedFieldReference(name), idx) => (idx, name)
filteredExprs.zipWithIndex flatMap {
case (UnresolvedFieldReference(name), idx) =>
Some((idx, name))
case (Alias(UnresolvedFieldReference(origName), name, _), _) =>
val idx = t.getFieldIndex(origName)
if (idx < 0) {
throw new TableException(s"$origName is not a field of type $t")
}
(idx, name)
Some((idx, name))
case (_: TimeAttribute, _) =>
None
case _ => throw new TableException(
"Field reference expression or alias on field expression expected.")
}
case c: CaseClassTypeInfo[A] =>
exprs.zipWithIndex.map {
case (UnresolvedFieldReference(name), idx) => (idx, name)
filteredExprs.zipWithIndex flatMap {
case (UnresolvedFieldReference(name), idx) =>
Some((idx, name))
case (Alias(UnresolvedFieldReference(origName), name, _), _) =>
val idx = c.getFieldIndex(origName)
if (idx < 0) {
throw new TableException(s"$origName is not a field of type $c")
}
(idx, name)
Some((idx, name))
case (_: TimeAttribute, _) =>
None
case _ => throw new TableException(
"Field reference expression or alias on field expression expected.")
}
case p: PojoTypeInfo[A] =>
exprs.map {
filteredExprs flatMap {
case (UnresolvedFieldReference(name)) =>
val idx = p.getFieldIndex(name)
if (idx < 0) {
throw new TableException(s"$name is not a field of type $p")
}
(idx, name)
Some((idx, name))
case Alias(UnresolvedFieldReference(origName), name, _) =>
val idx = p.getFieldIndex(origName)
if (idx < 0) {
throw new TableException(s"$origName is not a field of type $p")
}
(idx, name)
Some((idx, name))
case _: TimeAttribute =>
None
case _ => throw new TableException(
"Field reference expression or alias on field expression expected.")
}
......@@ -795,6 +822,42 @@ abstract class TableEnvironment(val config: TableConfig) {
Some(mapFunction)
}
/**
* Checks for at most one rowtime and proctime attribute.
* Returns the time attributes.
*
* @return rowtime attribute and proctime attribute
*/
protected def validateAndExtractTimeAttributes(
fieldNames: Seq[String],
fieldIndices: Seq[Int],
exprs: Array[Expression])
: (Option[(Int, String)], Option[(Int, String)]) = {
var rowtime: Option[(Int, String)] = None
var proctime: Option[(Int, String)] = None
exprs.zipWithIndex.foreach {
case (RowtimeAttribute(reference@UnresolvedFieldReference(name)), idx) =>
if (rowtime.isDefined) {
throw new TableException(
"The rowtime attribute can only be defined once in a table schema.")
} else {
rowtime = Some(idx, name)
}
case (ProctimeAttribute(reference@UnresolvedFieldReference(name)), idx) =>
if (proctime.isDefined) {
throw new TableException(
"The proctime attribute can only be defined once in a table schema.")
} else {
proctime = Some(idx, name)
}
case _ =>
// do nothing
}
(rowtime, proctime)
}
}
/**
......@@ -803,6 +866,10 @@ abstract class TableEnvironment(val config: TableConfig) {
*/
object TableEnvironment {
// default names that can be used in TableSources etc.
val DEFAULT_ROWTIME_ATTRIBUTE = "rowtime"
val DEFAULT_PROCTIME_ATTRIBUTE = "proctime"
/**
* Returns a [[JavaBatchTableEnv]] for a Java [[JavaBatchExecEnv]].
*
......
......@@ -625,7 +625,7 @@ trait ImplicitExpressionOperations {
*/
def millis = milli
// row interval type
// Row interval type
/**
* Creates an interval of rows.
......@@ -634,6 +634,8 @@ trait ImplicitExpressionOperations {
*/
def rows = toRowInterval(expr)
// Advanced type helper functions
/**
* Accesses the field of a Flink composite type (such as Tuple, POJO, etc.) by name and
* returns its value.
......@@ -680,6 +682,20 @@ trait ImplicitExpressionOperations {
* @return the first and only element of an array with a single element
*/
def element() = ArrayElement(expr)
// Schema definition
/**
* Declares a field as the rowtime attribute for indicating, accessing, and working in
* Flink's event time.
*/
def rowtime = RowtimeAttribute(expr)
/**
* Declares a field as the proctime attribute for indicating, accessing, and working in
* Flink's processing time.
*/
def proctime = ProctimeAttribute(expr)
}
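// Illustration only (not part of the diff): with the Scala expression DSL in scope, these
// helpers wrap a field reference into a time attribute expression (field names are hypothetical).
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.Expression

val rowtimeDecl: Expression = 'ts.rowtime   // RowtimeAttribute(UnresolvedFieldReference("ts"))
val proctimeDecl: Expression = 'p.proctime  // ProctimeAttribute(UnresolvedFieldReference("p"))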
/**
......
......@@ -18,8 +18,8 @@
package org.apache.flink.table.api.scala
import org.apache.flink.table.api.{OverWindowWithOrderBy, SessionWithGap, SlideWithSize, TumbleWithSize}
import org.apache.flink.table.expressions.Expression
import org.apache.flink.table.api.{TumbleWithSize, OverWindowWithOrderBy, SlideWithSize, SessionWithGap}
/**
* Helper object for creating a tumbling window. Tumbling windows are consecutive, non-overlapping
......
......@@ -20,12 +20,12 @@ package org.apache.flink.table.api
import org.apache.calcite.rel.RelNode
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
import org.apache.flink.table.calcite.FlinkRelBuilder
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.expressions.{Alias, Asc, Expression, ExpressionParser, Ordering, UnresolvedAlias, UnresolvedFieldReference}
import org.apache.flink.table.plan.logical.Minus
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.plan.ProjectionTranslator._
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.plan.logical.{Minus, _}
import org.apache.flink.table.sinks.TableSink
import _root_.scala.collection.JavaConverters._
......@@ -1015,13 +1015,7 @@ class WindowGroupedTable(
val projectsOnAgg = replaceAggregationsAndProperties(
fields, table.tableEnv, aggNames, propNames)
val projectFields = (table.tableEnv, window) match {
// event time can be arbitrary field in batch environment
case (_: BatchTableEnvironment, w: EventTimeWindow) =>
extractFieldReferences(fields ++ groupKeys ++ Seq(w.timeField))
case (_, _) =>
extractFieldReferences(fields ++ groupKeys)
}
val projectFields = extractFieldReferences(fields ++ groupKeys :+ window.timeField)
new Table(table.tableEnv,
Project(
......
......@@ -149,7 +149,7 @@ class OverWindowWithOrderBy(
* A window specification.
*
* Window groups rows based on time or row-count intervals. It is a general way to group the
* elements, which is very helpful for both groupby-aggregations and over-aggregations to
* elements, which is very helpful for both groupBy-aggregations and over-aggregations to
* compute aggregates on groups of elements.
*
* Infinite streaming tables can only be grouped into time or row intervals. Hence window grouping
......@@ -157,111 +157,73 @@ class OverWindowWithOrderBy(
*
* For finite batch tables, window provides shortcuts for time-based groupBy.
*
* @param alias The expression of alias for this Window
*/
abstract class Window(val alias: Expression) {
abstract class Window(val alias: Expression, val timeField: Expression) {
/**
* Converts an API class to a logical window for planning.
*/
private[flink] def toLogicalWindow: LogicalWindow
}
// ------------------------------------------------------------------------------------------------
// Tumbling windows
// ------------------------------------------------------------------------------------------------
/**
* A window specification without alias.
* Tumbling window.
*
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param size the size of the window either as time or row-count interval.
*/
abstract class WindowWithoutAlias {
class TumbleWithSize(size: Expression) {
/**
* Assigns an alias for this window that the following `groupBy()` and `select()` clause can
* refer to. `select()` statement can access window properties such as window start or end time.
* Tumbling window.
*
* @param alias alias for this window
* @return this window
*/
def as(alias: Expression): Window
/**
* Assigns an alias for this window that the following `groupBy()` and `select()` clause can
* refer to. `select()` statement can access window properties such as window start or end time.
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* @param alias alias for this window
* @return this window
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param size the size of the window either as time or row-count interval.
*/
def as(alias: String): Window = as(ExpressionParser.parseExpression(alias))
}
/**
* A predefined specification of window on processing-time
*/
abstract class ProcTimeWindowWithoutAlias extends WindowWithoutAlias {
def this(size: String) = this(ExpressionParser.parseExpression(size))
/**
* Specifies the time attribute on which rows are grouped.
*
* For streaming tables call [[on('rowtime)]] to specify grouping by event-time. Otherwise rows
* are grouped by processing-time.
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* For batch tables, refer to a timestamp or long attribute.
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param timeField time mode for streaming tables and time attribute for batch tables
* @return a predefined window on event-time
* @param timeField time attribute for streaming and batch tables
* @return a tumbling window on event-time
*/
def on(timeField: Expression): WindowWithoutAlias
def on(timeField: Expression): TumbleWithSizeOnTime =
new TumbleWithSizeOnTime(timeField, size)
/**
* Specifies the time attribute on which rows are grouped.
*
* For streaming tables call [[on('rowtime)]] to specify grouping by event-time. Otherwise rows
* are grouped by processing-time.
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* For batch tables, refer to a timestamp or long attribute.
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param timeField time mode for streaming tables and time attribute for batch tables
* @return a predefined window on event-time
* @param timeField time attribute for streaming and batch tables
* @return a tumbling window on event-time
*/
def on(timeField: String): WindowWithoutAlias =
def on(timeField: String): TumbleWithSizeOnTime =
on(ExpressionParser.parseExpression(timeField))
}
/**
* A window operating on event-time.
*
* For streaming tables call on('rowtime) to specify grouping by event-time.
* Otherwise rows are grouped by processing-time.
*
* For batch tables, refer to a timestamp or long attribute.
*
* @param timeField time mode for streaming tables and time attribute for batch tables
* Tumbling window on time.
*/
abstract class EventTimeWindow(alias: Expression, val timeField: Expression) extends Window(alias)
// ------------------------------------------------------------------------------------------------
// Tumbling windows
// ------------------------------------------------------------------------------------------------
/**
* A partial specification of a tumbling window.
*
* @param size the size of the window either a time or a row-count interval.
*/
class TumbleWithSize(size: Expression) extends ProcTimeWindowWithoutAlias {
def this(size: String) = this(ExpressionParser.parseExpression(size))
/**
* Specifies the time attribute on which rows are grouped.
*
* For streaming tables call [[on('rowtime)]] to specify grouping by event-time.
* Otherwise rows are grouped by processing-time.
*
* For batch tables, refer to a timestamp or long attribute.
*
* @param timeField time mode for streaming tables and time attribute for batch tables
* @return a predefined window on event-time
*/
override def on(timeField: Expression): WindowWithoutAlias =
new TumbleWithoutAlias(timeField, size)
class TumbleWithSizeOnTime(time: Expression, size: Expression) {
/**
* Assigns an alias for this window that the following `groupBy()` and `select()` clause can
......@@ -270,15 +232,9 @@ class TumbleWithSize(size: Expression) extends ProcTimeWindowWithoutAlias {
* @param alias alias for this window
* @return this window
*/
override def as(alias: Expression) = new TumblingWindow(alias, size)
}
/**
* A tumbling window on event-time without alias.
*/
class TumbleWithoutAlias(
time: Expression,
size: Expression) extends WindowWithoutAlias {
def as(alias: Expression): TumbleWithSizeOnTimeWithAlias = {
new TumbleWithSizeOnTimeWithAlias(alias, time, size)
}
/**
* Assigns an alias for this window that the following `groupBy()` and `select()` clause can
......@@ -287,31 +243,28 @@ class TumbleWithoutAlias(
* @param alias alias for this window
* @return this window
*/
override def as(alias: Expression): Window = new TumblingEventTimeWindow(alias, time, size)
}
/**
* Tumbling window on processing-time.
*
* @param alias the alias of the window.
* @param size the size of the window either a time or a row-count interval.
*/
class TumblingWindow(alias: Expression, size: Expression) extends Window(alias) {
override private[flink] def toLogicalWindow: LogicalWindow =
ProcessingTimeTumblingGroupWindow(alias, size)
def as(alias: String): TumbleWithSizeOnTimeWithAlias = {
as(ExpressionParser.parseExpression(alias))
}
}
/**
* Tumbling window on event-time.
* Tumbling window on time with alias. Fully specifies a window.
*/
class TumblingEventTimeWindow(
class TumbleWithSizeOnTimeWithAlias(
alias: Expression,
time: Expression,
size: Expression) extends EventTimeWindow(alias, time) {
timeField: Expression,
size: Expression)
extends Window(
alias,
timeField) {
override private[flink] def toLogicalWindow: LogicalWindow =
EventTimeTumblingGroupWindow(alias, time, size)
/**
* Converts an API class to a logical window for planning.
*/
override private[flink] def toLogicalWindow: LogicalWindow = {
TumblingGroupWindow(alias, timeField, size)
}
}
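// Illustration only (not part of the diff): how the renamed tumbling window classes chain in
// the Scala Table API; the table and the fields 'name, 'amount, 'rowtime are hypothetical.
import org.apache.flink.table.api.Table
import org.apache.flink.table.api.scala._

def tumblingSum(table: Table): Table =
  table
    // Tumble -> TumbleWithSize -> TumbleWithSizeOnTime -> TumbleWithSizeOnTimeWithAlias
    .window(Tumble over 10.minutes on 'rowtime as 'w)
    .groupBy('w, 'name)
    .select('name, 'w.start, 'amount.sum)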
// ------------------------------------------------------------------------------------------------
......@@ -319,16 +272,16 @@ class TumblingEventTimeWindow(
// ------------------------------------------------------------------------------------------------
/**
* A partially specified sliding window.
* Partially specified sliding window.
*
* @param size the size of the window either a time or a row-count interval.
* @param size the size of the window either as time or row-count interval.
*/
class SlideWithSize(size: Expression) {
/**
* A partially specified sliding window.
* Partially specified sliding window.
*
* @param size the size of the window either a time or a row-count interval.
* @param size the size of the window either as time or row-count interval.
*/
def this(size: String) = this(ExpressionParser.parseExpression(size))
......@@ -343,9 +296,9 @@ class SlideWithSize(size: Expression) {
* windows.
*
* @param slide the slide of the window either as time or row-count interval.
* @return a predefined sliding window.
* @return a sliding window
*/
def every(slide: Expression): SlideWithSlide = new SlideWithSlide(size, slide)
def every(slide: Expression): SlideWithSizeAndSlide = new SlideWithSizeAndSlide(size, slide)
/**
* Specifies the window's slide as time or row-count interval.
......@@ -358,48 +311,54 @@ class SlideWithSize(size: Expression) {
* windows.
*
* @param slide the slide of the window either as time or row-count interval.
* @return a predefined sliding window.
* @return a sliding window
*/
def every(slide: String): WindowWithoutAlias = every(ExpressionParser.parseExpression(slide))
def every(slide: String): SlideWithSizeAndSlide = every(ExpressionParser.parseExpression(slide))
}
/**
* A partially defined sliding window.
* Sliding window.
*
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param size the size of the window either as time or row-count interval.
*/
class SlideWithSlide(
size: Expression,
slide: Expression) extends ProcTimeWindowWithoutAlias {
class SlideWithSizeAndSlide(size: Expression, slide: Expression) {
/**
* Specifies the time attribute on which rows are grouped.
*
* For streaming tables call [[on('rowtime)]] to specify grouping by event-time. Otherwise rows
* are grouped by processing-time.
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* For batch tables, refer to a timestamp or long attribute.
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param timeField time mode for streaming tables and time attribute for batch tables
* @return a predefined Sliding window on event-time.
* @param timeField time attribute for streaming and batch tables
* @return a sliding window on event-time
*/
override def on(timeField: Expression): SlideWithoutAlias =
new SlideWithoutAlias(timeField, size, slide)
def on(timeField: Expression): SlideWithSizeAndSlideOnTime =
new SlideWithSizeAndSlideOnTime(timeField, size, slide)
/**
* Assigns an alias for this window that the following `groupBy()` and `select()` clause can
* refer to. `select()` statement can access window properties such as window start or end time.
* Specifies the time attribute on which rows are grouped.
*
* @param alias alias for this window
* @return this window
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param timeField time attribute for streaming and batch tables
* @return a sliding window on event-time
*/
override def as(alias: Expression): Window = new SlidingWindow(alias, size, slide)
def on(timeField: String): SlideWithSizeAndSlideOnTime =
on(ExpressionParser.parseExpression(timeField))
}
/**
* A partially defined sliding window on event-time without alias.
* Sliding window on time.
*/
class SlideWithoutAlias(
timeField: Expression,
size: Expression,
slide: Expression) extends WindowWithoutAlias {
class SlideWithSizeAndSlideOnTime(timeField: Expression, size: Expression, slide: Expression) {
/**
* Assigns an alias for this window that the following `groupBy()` and `select()` clause can
* refer to. `select()` statement can access window properties such as window start or end time.
......@@ -407,39 +366,40 @@ class SlideWithoutAlias(
* @param alias alias for this window
* @return this window
*/
override def as(alias: Expression): Window =
new SlidingEventTimeWindow(alias, timeField, size, slide)
}
def as(alias: Expression): SlideWithSizeAndSlideOnTimeWithAlias = {
new SlideWithSizeAndSlideOnTimeWithAlias(alias, timeField, size, slide)
}
/**
* A sliding window on processing-time.
*
* @param alias the alias of the window.
* @param size the size of the window either a time or a row-count interval.
* @param slide the interval by which the window slides.
*/
class SlidingWindow(
alias: Expression,
size: Expression,
slide: Expression)
extends Window(alias) {
override private[flink] def toLogicalWindow: LogicalWindow =
ProcessingTimeSlidingGroupWindow(alias, size, slide)
/**
* Assigns an alias for this window that the following `groupBy()` and `select()` clause can
* refer to. `select()` statement can access window properties such as window start or end time.
*
* @param alias alias for this window
* @return this window
*/
def as(alias: String): SlideWithSizeAndSlideOnTimeWithAlias = {
as(ExpressionParser.parseExpression(alias))
}
}
/**
* A sliding window on event-time.
* Sliding window on time with alias. Fully specifies a window.
*/
class SlidingEventTimeWindow(
class SlideWithSizeAndSlideOnTimeWithAlias(
alias: Expression,
timeField: Expression,
size: Expression,
slide: Expression)
extends EventTimeWindow(alias, timeField) {
extends Window(
alias,
timeField) {
override private[flink] def toLogicalWindow: LogicalWindow =
EventTimeSlidingGroupWindow(alias, timeField, size, slide)
/**
* Converts an API class to a logical window for planning.
*/
override private[flink] def toLogicalWindow: LogicalWindow = {
SlidingGroupWindow(alias, timeField, size, slide)
}
}
// ------------------------------------------------------------------------------------------------
......@@ -447,42 +407,59 @@ class SlidingEventTimeWindow(
// ------------------------------------------------------------------------------------------------
/**
* A partially defined session window.
* Session window.
*
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param gap the time interval of inactivity before a window is closed.
*/
class SessionWithGap(gap: Expression) extends ProcTimeWindowWithoutAlias {
class SessionWithGap(gap: Expression) {
/**
* Session window.
*
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param gap the time interval of inactivity before a window is closed.
*/
def this(gap: String) = this(ExpressionParser.parseExpression(gap))
/**
* Specifies the time attribute on which rows are grouped.
*
* For streaming tables call [[on('rowtime)]] to specify grouping by event-time. Otherwise rows
* are grouped by processing-time.
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* For batch tables, refer to a timestamp or long attribute.
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param timeField time mode for streaming tables and time attribute for batch tables
* @return an on event-time session window on event-time
* @param timeField time attribute for streaming and batch tables
* @return a session window on event-time
*/
override def on(timeField: Expression): SessionWithoutAlias =
new SessionWithoutAlias(timeField, gap)
def on(timeField: Expression): SessionWithGapOnTime =
new SessionWithGapOnTime(timeField, gap)
/**
* Assigns an alias for this window that the following `groupBy()` and `select()` clause can
* refer to. `select()` statement can access window properties such as window start or end time.
* Specifies the time attribute on which rows are grouped.
*
* @param alias alias for this window
* @return this window
* For streaming tables you can specify grouping by an event-time or processing-time attribute.
*
* For batch tables you can specify grouping on a timestamp or long attribute.
*
* @param timeField time attribute for streaming and batch tables
* @return a session window on event-time
*/
override def as(alias: Expression): Window = new SessionWindow(alias, gap)
def on(timeField: String): SessionWithGapOnTime =
on(ExpressionParser.parseExpression(timeField))
}
/**
* A partially defined session window on event-time without alias.
* Session window on time.
*/
class SessionWithoutAlias(
timeField: Expression,
gap: Expression) extends WindowWithoutAlias {
class SessionWithGapOnTime(timeField: Expression, gap: Expression) {
/**
* Assigns an alias for this window that the following `groupBy()` and `select()` clause can
* refer to. `select()` statement can access window properties such as window start or end time.
......@@ -490,29 +467,37 @@ class SessionWithoutAlias(
* @param alias alias for this window
* @return this window
*/
override def as(alias: Expression): Window = new SessionEventTimeWindow(alias, timeField, gap)
}
/**
* A session window on processing-time.
*
* @param gap the time interval of inactivity before a window is closed.
*/
class SessionWindow(alias: Expression, gap: Expression) extends Window(alias) {
def as(alias: Expression): SessionWithGapOnTimeWithAlias = {
new SessionWithGapOnTimeWithAlias(alias, timeField, gap)
}
override private[flink] def toLogicalWindow: LogicalWindow =
ProcessingTimeSessionGroupWindow(alias, gap)
/**
* Assigns an alias for this window that the following `groupBy()` and `select()` clause can
* refer to. `select()` statement can access window properties such as window start or end time.
*
* @param alias alias for this window
* @return this window
*/
def as(alias: String): SessionWithGapOnTimeWithAlias = {
as(ExpressionParser.parseExpression(alias))
}
}
/**
* A session window on event-time.
* Session window on time with alias. Fully specifies a window.
*/
class SessionEventTimeWindow(
class SessionWithGapOnTimeWithAlias(
alias: Expression,
timeField: Expression,
gap: Expression)
extends EventTimeWindow(alias, timeField) {
extends Window(
alias,
timeField) {
override private[flink] def toLogicalWindow: LogicalWindow =
EventTimeSessionGroupWindow(alias, timeField, gap)
/**
* Converts an API class to a logical window for planning.
*/
override private[flink] def toLogicalWindow: LogicalWindow = {
SessionGroupWindow(alias, timeField, gap)
}
}
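// Illustration only (not part of the diff): the sliding and session variants chain analogously
// (hypothetical table and field names).
import org.apache.flink.table.api.Table
import org.apache.flink.table.api.scala._

def slidingSum(table: Table): Table =
  table
    .window(Slide over 10.minutes every 5.minutes on 'rowtime as 'w)
    .groupBy('w)
    .select('w.start, 'amount.sum)

def sessionCount(table: Table): Table =
  table
    .window(Session withGap 30.minutes on 'proctime as 'w)
    .groupBy('w)
    .select('w.end, 'amount.count)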
......@@ -21,8 +21,8 @@ package org.apache.flink.table.calcite
import org.apache.calcite.adapter.java.JavaTypeFactory
import org.apache.calcite.prepare.CalciteCatalogReader
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.sql.validate.{SqlConformance, SqlValidatorImpl}
import org.apache.calcite.sql.{SqlInsert, SqlOperatorTable}
import org.apache.calcite.sql._
import org.apache.calcite.sql.validate.{SqlConformanceEnum, SqlValidatorImpl}
/**
* This is a copy of Calcite's CalciteSqlValidator to use with [[FlinkPlannerImpl]].
......@@ -30,8 +30,12 @@ import org.apache.calcite.sql.{SqlInsert, SqlOperatorTable}
class FlinkCalciteSqlValidator(
opTab: SqlOperatorTable,
catalogReader: CalciteCatalogReader,
typeFactory: JavaTypeFactory) extends SqlValidatorImpl(
opTab, catalogReader, typeFactory, SqlConformance.DEFAULT) {
factory: JavaTypeFactory)
extends SqlValidatorImpl(
opTab,
catalogReader,
factory,
SqlConformanceEnum.DEFAULT) {
override def getLogicalSourceRowType(
sourceRowType: RelDataType,
......
......@@ -107,7 +107,11 @@ class FlinkPlannerImpl(
// we disable automatic flattening in order to let composite types pass without modification
// we might enable it again once Calcite has better support for structured types
// root = root.withRel(sqlToRelConverter.flattenTypes(root.rel, true))
root = root.withRel(RelDecorrelator.decorrelateQuery(root.rel))
// TableEnvironment.optimize will execute the following
// root = root.withRel(RelDecorrelator.decorrelateQuery(root.rel))
// convert time indicators
// root = root.withRel(RelTimeIndicatorConverter.convert(root.rel, rexBuilder))
root
} catch {
case e: RelConversionException => throw TableException(e.getMessage)
......
......@@ -20,25 +20,25 @@ package org.apache.flink.table.calcite
import org.apache.calcite.avatica.util.TimeUnit
import org.apache.calcite.jdbc.JavaTypeFactoryImpl
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
import org.apache.calcite.rel.`type`._
import org.apache.calcite.sql.SqlIntervalQualifier
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.`type`.{BasicSqlType, SqlTypeName}
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.{NothingTypeInfo, PrimitiveArrayTypeInfo, SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
import org.apache.flink.api.java.typeutils.ValueTypeInfo._
import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.schema.{ArrayRelDataType, CompositeRelDataType, GenericRelDataType, MapRelDataType}
import org.apache.flink.table.typeutils.TimeIntervalTypeInfo
import org.apache.flink.table.typeutils.TypeCheckUtils.isSimple
import org.apache.flink.table.calcite.FlinkTypeFactory.typeInfoToSqlTypeName
import org.apache.flink.table.plan.schema._
import org.apache.flink.table.typeutils.TypeCheckUtils.isSimple
import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TimeIntervalTypeInfo}
import org.apache.flink.types.Row
import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.collection.mutable
/**
* Flink specific type factory that represents the interface between Flink's [[TypeInformation]]
......@@ -65,6 +65,12 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
createSqlIntervalType(
new SqlIntervalQualifier(TimeUnit.DAY, TimeUnit.SECOND, SqlParserPos.ZERO))
case TimeIndicatorTypeInfo.ROWTIME_INDICATOR =>
createRowtimeIndicatorType()
case TimeIndicatorTypeInfo.PROCTIME_INDICATOR =>
createProctimeIndicatorType()
case _ =>
createSqlType(sqlType)
}
......@@ -76,24 +82,76 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
}
}
/**
* Creates an indicator type for processing-time, but with similar properties as a SQL timestamp.
*/
def createProctimeIndicatorType(): RelDataType = {
val originalType = createTypeFromTypeInfo(SqlTimeTypeInfo.TIMESTAMP)
canonize(
new TimeIndicatorRelDataType(
getTypeSystem,
originalType.asInstanceOf[BasicSqlType],
isEventTime = false)
)
}
/**
* Creates an indicator type for event-time, but with similar properties as a SQL timestamp.
*/
def createRowtimeIndicatorType(): RelDataType = {
val originalType = createTypeFromTypeInfo(SqlTimeTypeInfo.TIMESTAMP)
canonize(
new TimeIndicatorRelDataType(
getTypeSystem,
originalType.asInstanceOf[BasicSqlType],
isEventTime = true)
)
}
/**
* Creates a struct type with the input fieldNames and input fieldTypes using FlinkTypeFactory
*
* @param fieldNames field names
* @param fieldTypes field types, every element is Flink's [[TypeInformation]]
* @return a struct type with the input fieldNames and input fieldTypes
* @param rowtime optional system field to indicate event-time; the index determines the index
* in the final record and might replace an existing field
* @param proctime optional system field to indicate processing-time; the index determines the
* index in the final record and might replace an existing field
* @return a struct type with the input fieldNames, input fieldTypes, and system fields
*/
def buildRowDataType(
fieldNames: Array[String],
fieldTypes: Array[TypeInformation[_]])
def buildLogicalRowType(
fieldNames: Seq[String],
fieldTypes: Seq[TypeInformation[_]],
rowtime: Option[(Int, String)],
proctime: Option[(Int, String)])
: RelDataType = {
val rowDataTypeBuilder = builder
fieldNames
.zip(fieldTypes)
.foreach { f =>
rowDataTypeBuilder.add(f._1, createTypeFromTypeInfo(f._2)).nullable(true)
val logicalRowTypeBuilder = builder
val fields = fieldNames.zip(fieldTypes)
var totalNumberOfFields = fields.length
if (rowtime.isDefined) {
totalNumberOfFields += 1
}
if (proctime.isDefined) {
totalNumberOfFields += 1
}
var addedTimeAttributes = 0
for (i <- 0 until totalNumberOfFields) {
if (rowtime.isDefined && rowtime.get._1 == i) {
logicalRowTypeBuilder.add(rowtime.get._2, createRowtimeIndicatorType())
addedTimeAttributes += 1
} else if (proctime.isDefined && proctime.get._1 == i) {
logicalRowTypeBuilder.add(proctime.get._2, createProctimeIndicatorType())
addedTimeAttributes += 1
} else {
val field = fields(i - addedTimeAttributes)
logicalRowTypeBuilder.add(field._1, createTypeFromTypeInfo(field._2)).nullable(true)
}
rowDataTypeBuilder.build
}
logicalRowTypeBuilder.build
}
override def createSqlType(typeName: SqlTypeName, precision: Int): RelDataType = {
......@@ -178,6 +236,7 @@ object FlinkTypeFactory {
/**
* Converts a Calcite logical record into a Flink type information.
*/
@deprecated("Use the RowSchema class instead because it handles both logical and physical rows.")
def toInternalRowTypeInfo(logicalRowType: RelDataType): TypeInformation[Row] = {
// convert to type information
val logicalFieldTypes = logicalRowType.getFieldList.asScala map { relDataType =>
......@@ -188,6 +247,36 @@ object FlinkTypeFactory {
new RowTypeInfo(logicalFieldTypes.toArray, logicalFieldNames.toArray)
}
def isProctimeIndicatorType(relDataType: RelDataType): Boolean = relDataType match {
case ti: TimeIndicatorRelDataType if !ti.isEventTime => true
case _ => false
}
def isProctimeIndicatorType(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
case ti: TimeIndicatorTypeInfo if !ti.isEventTime => true
case _ => false
}
def isRowtimeIndicatorType(relDataType: RelDataType): Boolean = relDataType match {
case ti: TimeIndicatorRelDataType if ti.isEventTime => true
case _ => false
}
def isRowtimeIndicatorType(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
case ti: TimeIndicatorTypeInfo if ti.isEventTime => true
case _ => false
}
def isTimeIndicatorType(relDataType: RelDataType): Boolean = relDataType match {
case ti: TimeIndicatorRelDataType => true
case _ => false
}
def isTimeIndicatorType(typeInfo: TypeInformation[_]): Boolean = typeInfo match {
case ti: TimeIndicatorTypeInfo => true
case _ => false
}
def toTypeInfo(relDataType: RelDataType): TypeInformation[_] = relDataType.getSqlTypeName match {
case BOOLEAN => BOOLEAN_TYPE_INFO
case TINYINT => BYTE_TYPE_INFO
......@@ -199,6 +288,15 @@ object FlinkTypeFactory {
case VARCHAR | CHAR => STRING_TYPE_INFO
case DECIMAL => BIG_DEC_TYPE_INFO
// time indicators
case TIMESTAMP if relDataType.isInstanceOf[TimeIndicatorRelDataType] =>
val indicator = relDataType.asInstanceOf[TimeIndicatorRelDataType]
if (indicator.isEventTime) {
TimeIndicatorTypeInfo.ROWTIME_INDICATOR
} else {
TimeIndicatorTypeInfo.PROCTIME_INDICATOR
}
// temporal types
case DATE => SqlTimeTypeInfo.DATE
case TIME => SqlTimeTypeInfo.TIME
......
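// Illustration only (not part of the diff): a sketch of how the new FlinkTypeFactory predicates
// classify the indicator types, assuming a fresh factory instance.
import org.apache.calcite.rel.`type`.RelDataTypeSystem
import org.apache.flink.table.calcite.FlinkTypeFactory

val factory = new FlinkTypeFactory(RelDataTypeSystem.DEFAULT)
val rowtimeType = factory.createRowtimeIndicatorType()

FlinkTypeFactory.isTimeIndicatorType(rowtimeType)     // true
FlinkTypeFactory.isRowtimeIndicatorType(rowtimeType)  // true
FlinkTypeFactory.isProctimeIndicatorType(rowtimeType) // false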
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.calcite
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFieldImpl, RelRecordType, StructKind}
import org.apache.calcite.rel.logical._
import org.apache.calcite.rel.{RelNode, RelShuttleImpl}
import org.apache.calcite.rex._
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.calcite.FlinkTypeFactory.isTimeIndicatorType
import org.apache.flink.table.functions.TimeMaterializationSqlFunction
import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType
import scala.collection.JavaConversions._
/**
* Traverses a [[RelNode]] tree and converts fields with [[TimeIndicatorRelDataType]] type. If a
* time attribute is accessed for a calculation, it will be materialized. Forwarding is allowed in
* some cases, but not all.
*/
class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttleImpl {
override def visit(project: LogicalProject): RelNode = {
// visit children and update inputs
val updatedProject = super.visit(project).asInstanceOf[LogicalProject]
// check if input field contains time indicator type
// materialize field if no time indicator is present anymore
// if input field is already materialized, change to timestamp type
val materializer = new RexTimeIndicatorMaterializer(
rexBuilder,
updatedProject.getInput.getRowType.getFieldList.map(_.getType))
val newProjects = updatedProject.getProjects.map(_.accept(materializer))
// copy project
updatedProject.copy(
updatedProject.getTraitSet,
updatedProject.getInput,
newProjects,
buildRowType(updatedProject.getRowType.getFieldNames, newProjects.map(_.getType))
)
}
override def visit(filter: LogicalFilter): RelNode = {
// visit children and update inputs
val updatedFilter = super.visit(filter).asInstanceOf[LogicalFilter]
// check if input field contains time indicator type
// materialize field if no time indicator is present anymore
// if input field is already materialized, change to timestamp type
val materializer = new RexTimeIndicatorMaterializer(
rexBuilder,
updatedFilter.getInput.getRowType.getFieldList.map(_.getType))
val newCondition = updatedFilter.getCondition.accept(materializer)
// copy filter
updatedFilter.copy(
updatedFilter.getTraitSet,
updatedFilter.getInput,
newCondition
)
}
override def visit(union: LogicalUnion): RelNode = {
// visit children and update inputs
val updatedUnion = super.visit(union).asInstanceOf[LogicalUnion]
// make sure that time indicator types match
val inputTypes = updatedUnion.getInputs.map(_.getRowType)
val head = inputTypes.head.getFieldList.map(_.getType)
val isValid = inputTypes.forall { t =>
val fieldTypes = t.getFieldList.map(_.getType)
fieldTypes.zip(head).forall { case (l, r) =>
// check if time indicators match
if (isTimeIndicatorType(l) && isTimeIndicatorType(r)) {
val leftTime = l.asInstanceOf[TimeIndicatorRelDataType].isEventTime
val rightTime = r.asInstanceOf[TimeIndicatorRelDataType].isEventTime
leftTime == rightTime
}
// one side is not an indicator
else if (isTimeIndicatorType(l) || isTimeIndicatorType(r)) {
false
}
// uninteresting types
else {
true
}
}
}
if (!isValid) {
throw new ValidationException(
"Union fields with time attributes have different types.")
}
updatedUnion
}
override def visit(other: RelNode): RelNode = other match {
case scan: LogicalTableFunctionScan if
stack.size() > 0 && stack.peek().isInstanceOf[LogicalCorrelate] =>
// visit children and update inputs
val updatedScan = super.visit(scan).asInstanceOf[LogicalTableFunctionScan]
val correlate = stack.peek().asInstanceOf[LogicalCorrelate]
// check if input field contains time indicator type
// materialize field if no time indicator is present anymore
// if input field is already materialized, change to timestamp type
val materializer = new RexTimeIndicatorMaterializer(
rexBuilder,
correlate.getInputs.get(0).getRowType.getFieldList.map(_.getType))
val newCall = updatedScan.getCall.accept(materializer)
// copy scan
updatedScan.copy(
updatedScan.getTraitSet,
updatedScan.getInputs,
newCall,
updatedScan.getElementType,
updatedScan.getRowType,
updatedScan.getColumnMappings
)
case _ =>
super.visit(other)
}
private def buildRowType(names: Seq[String], types: Seq[RelDataType]): RelDataType = {
val fields = names.zipWithIndex.map { case (name, idx) =>
new RelDataTypeFieldImpl(name, idx, types(idx))
}
new RelRecordType(StructKind.FULLY_QUALIFIED, fields)
}
}
class RexTimeIndicatorMaterializer(
private val rexBuilder: RexBuilder,
private val input: Seq[RelDataType])
extends RexShuttle {
val timestamp = rexBuilder
.getTypeFactory
.asInstanceOf[FlinkTypeFactory]
.createTypeFromTypeInfo(SqlTimeTypeInfo.TIMESTAMP)
override def visitInputRef(inputRef: RexInputRef): RexNode = {
// reference is interesting
if (isTimeIndicatorType(inputRef.getType)) {
val resolvedRefType = input(inputRef.getIndex)
// input is a valid time indicator
if (isTimeIndicatorType(resolvedRefType)) {
inputRef
}
// input has been materialized
else {
new RexInputRef(inputRef.getIndex, resolvedRefType)
}
}
// reference is a regular field
else {
super.visitInputRef(inputRef)
}
}
override def visitCall(call: RexCall): RexNode = {
val updatedCall = super.visitCall(call).asInstanceOf[RexCall]
// skip materialization for special operators
updatedCall.getOperator match {
case SqlStdOperatorTable.SESSION | SqlStdOperatorTable.HOP | SqlStdOperatorTable.TUMBLE =>
return updatedCall
case _ => // do nothing
}
// materialize operands with time indicators
val materializedOperands = updatedCall.getOperands.map { o =>
if (isTimeIndicatorType(o.getType)) {
rexBuilder.makeCall(TimeMaterializationSqlFunction, o)
} else {
o
}
}
// remove time indicator return type
if (isTimeIndicatorType(updatedCall.getType)) {
updatedCall.clone(timestamp, materializedOperands)
} else {
updatedCall.clone(updatedCall.getType, materializedOperands)
}
}
}
object RelTimeIndicatorConverter {
def convert(rootRel: RelNode, rexBuilder: RexBuilder): RelNode = {
val converter = new RelTimeIndicatorConverter(rexBuilder)
rootRel.accept(converter)
}
}
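// Illustration only (not part of the diff): the converter is invoked once per query during
// optimization, mirroring the StreamTableEnvironment.optimize() change above; the plan and the
// builder are assumed to come from the planning environment.
import org.apache.calcite.rel.RelNode
import org.apache.calcite.sql2rel.RelDecorrelator
import org.apache.flink.table.calcite.{FlinkRelBuilder, RelTimeIndicatorConverter}

def convertTimeIndicators(relNode: RelNode, relBuilder: FlinkRelBuilder): RelNode = {
  val decorPlan = RelDecorrelator.decorrelateQuery(relNode)
  RelTimeIndicatorConverter.convert(decorPlan, relBuilder.getRexBuilder)
}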
......@@ -23,7 +23,6 @@ import java.lang.{Iterable => JIterable}
import java.math.{BigDecimal => JBigDecimal}
import org.apache.calcite.avatica.util.DateTimeUtils
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.`type`.SqlTypeName._
......@@ -42,8 +41,8 @@ import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.codegen.calls.FunctionGenerator
import org.apache.flink.table.codegen.calls.ScalarOperators._
import org.apache.flink.table.functions.{AggregateFunction, FunctionContext, UserDefinedFunction}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.functions.{AggregateFunction, FunctionContext, TimeMaterializationSqlFunction, UserDefinedFunction}
import org.apache.flink.table.runtime.TableFunctionCollector
import org.apache.flink.table.typeutils.TypeCheckUtils._
import org.apache.flink.types.Row
......@@ -59,19 +58,18 @@ import scala.collection.mutable
* @param nullableInput input(s) can be null.
* @param input1 type information about the first input of the Function
* @param input2 type information about the second input if the Function is binary
* @param input1PojoFieldMapping additional mapping information if input1 is a POJO (POJO types
* have no deterministic field order).
* @param input2PojoFieldMapping additional mapping information if input2 is a POJO (POJO types
* have no deterministic field order).
*
* @param input1FieldMapping additional mapping information for input1
* (e.g. POJO types have no deterministic field order and some input fields might not be read)
* @param input2FieldMapping additional mapping information for input2
* (e.g. POJO types have no deterministic field order and some input fields might not be read)
*/
class CodeGenerator(
config: TableConfig,
nullableInput: Boolean,
input1: TypeInformation[_ <: Any],
input2: Option[TypeInformation[_ <: Any]] = None,
input1PojoFieldMapping: Option[Array[Int]] = None,
input2PojoFieldMapping: Option[Array[Int]] = None)
config: TableConfig,
nullableInput: Boolean,
input1: TypeInformation[_ <: Any],
input2: Option[TypeInformation[_ <: Any]] = None,
input1FieldMapping: Option[Array[Int]] = None,
input2FieldMapping: Option[Array[Int]] = None)
extends RexVisitor[GeneratedExpression] {
// check if nullCheck is enabled when inputs can be null
......@@ -82,7 +80,7 @@ class CodeGenerator(
// check for POJO input1 mapping
input1 match {
case pt: PojoTypeInfo[_] =>
input1PojoFieldMapping.getOrElse(
input1FieldMapping.getOrElse(
throw new CodeGenException("No input mapping is specified for input1 of type POJO."))
case _ => // ok
}
......@@ -90,11 +88,24 @@ class CodeGenerator(
// check for POJO input2 mapping
input2 match {
case Some(pt: PojoTypeInfo[_]) =>
input2PojoFieldMapping.getOrElse(
input2FieldMapping.getOrElse(
throw new CodeGenException("No input mapping is specified for input2 of type POJO."))
case _ => // ok
}
private val input1Mapping = input1FieldMapping match {
case Some(mapping) => mapping
case _ => (0 until input1.getArity).toArray
}
private val input2Mapping = input2FieldMapping match {
case Some(mapping) => mapping
case _ => input2 match {
case Some(input) => (0 until input.getArity).toArray
case _ => Array[Int]()
}
}
/**
* A code generator for generating unary Flink
* [[org.apache.flink.api.common.functions.Function]]s with one input.
......@@ -102,15 +113,15 @@ class CodeGenerator(
* @param config configuration that determines runtime behavior
* @param nullableInput input(s) can be null.
* @param input type information about the input of the Function
* @param inputPojoFieldMapping additional mapping information necessary if input is a
* POJO (POJO types have no deterministic field order).
* @param inputFieldMapping additional mapping information necessary for input
* (e.g. POJO types have no deterministic field order and some input fields might not be read)
*/
def this(
config: TableConfig,
nullableInput: Boolean,
input: TypeInformation[Any],
inputPojoFieldMapping: Array[Int]) =
this(config, nullableInput, input, None, Some(inputPojoFieldMapping))
inputFieldMapping: Array[Int]) =
this(config, nullableInput, input, None, Some(inputFieldMapping))
/**
* A code generator for generating Flink input formats.
......@@ -249,7 +260,7 @@ class CodeGenerator(
* @param name Class name of the function.
* Does not need to be unique but has to be a valid Java class identifier.
* @param generator The code generator instance
* @param inputType Input row type
* @param physicalInputTypes Physical input row types
* @param aggregates All aggregate functions
* @param aggFields Indexes of the input fields for all aggregate functions
* @param aggMapping The mapping of aggregates to output fields
......@@ -270,7 +281,7 @@ class CodeGenerator(
def generateAggregations(
name: String,
generator: CodeGenerator,
inputType: RelDataType,
physicalInputTypes: Seq[TypeInformation[_]],
aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]],
aggFields: Array[Array[Int]],
aggMapping: Array[Int],
......@@ -295,8 +306,7 @@ class CodeGenerator(
val accTypes = accTypeClasses.map(_.getCanonicalName)
// get java classes of input fields
val javaClasses = inputType.getFieldList
.map(f => FlinkTypeFactory.toTypeInfo(f.getType).getTypeClass)
val javaClasses = physicalInputTypes.map(t => t.getTypeClass)
// get parameter lists for aggregation functions
val parameters = aggFields.map { inFields =>
val fields = for (f <- inFields) yield
......@@ -844,12 +854,12 @@ class CodeGenerator(
returnType: TypeInformation[_ <: Any],
resultFieldNames: Seq[String])
: GeneratedExpression = {
val input1AccessExprs = for (i <- 0 until input1.getArity)
yield generateInputAccess(input1, input1Term, i, input1PojoFieldMapping)
val input1AccessExprs = for (i <- 0 until input1.getArity if input1Mapping.contains(i))
yield generateInputAccess(input1, input1Term, i, input1Mapping)
val input2AccessExprs = input2 match {
case Some(ti) => for (i <- 0 until ti.getArity)
yield generateInputAccess(ti, input2Term, i, input2PojoFieldMapping)
case Some(ti) => for (i <- 0 until ti.getArity if input2Mapping.contains(i))
yield generateInputAccess(ti, input2Term, i, input2Mapping)
case None => Seq() // add nothing
}
......@@ -861,14 +871,14 @@ class CodeGenerator(
*/
def generateCorrelateAccessExprs: (Seq[GeneratedExpression], Seq[GeneratedExpression]) = {
val input1AccessExprs = for (i <- 0 until input1.getArity)
yield generateInputAccess(input1, input1Term, i, input1PojoFieldMapping)
yield generateInputAccess(input1, input1Term, i, input1Mapping)
val input2AccessExprs = input2 match {
case Some(ti) => for (i <- 0 until ti.getArity)
case Some(ti) => for (i <- 0 until ti.getArity if input2Mapping.contains(i))
// use generateFieldAccess instead of generateInputAccess to avoid the generated table
// function's field access code is put on the top of function body rather than
// the while loop
yield generateFieldAccess(ti, input2Term, i, input2PojoFieldMapping)
yield generateFieldAccess(ti, input2Term, i, input2Mapping)
case None => throw new CodeGenException("Type information of input2 must not be null.")
}
(input1AccessExprs, input2AccessExprs)
......@@ -1123,11 +1133,11 @@ class CodeGenerator(
override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = {
// if the inputRef index is within the arity of input1 we work with input1, otherwise with input2
val input = if (inputRef.getIndex < input1.getArity) {
(input1, input1Term, input1PojoFieldMapping)
(input1, input1Term, input1Mapping)
} else {
(input2.getOrElse(throw new CodeGenException("Invalid input access.")),
input2Term,
input2PojoFieldMapping)
input2Mapping)
}
val index = if (input._2 == input1Term) {
......@@ -1146,7 +1156,7 @@ class CodeGenerator(
refExpr.resultType,
refExpr.resultTerm,
index,
input1PojoFieldMapping)
input1Mapping)
val resultTerm = newName("result")
val nullTerm = newName("isNull")
......@@ -1302,6 +1312,11 @@ class CodeGenerator(
throw new CodeGenException("Dynamic parameter references are not supported yet.")
override def visitCall(call: RexCall): GeneratedExpression = {
// time materialization is not implemented yet
if (call.getOperator == TimeMaterializationSqlFunction) {
throw new CodeGenException("Access to time attributes is not possible yet.")
}
val operands = call.getOperands.map(_.accept(this))
val resultType = FlinkTypeFactory.toTypeInfo(call.getType)
......@@ -1546,7 +1561,7 @@ class CodeGenerator(
inputType: TypeInformation[_ <: Any],
inputTerm: String,
index: Int,
pojoFieldMapping: Option[Array[Int]])
fieldMapping: Array[Int])
: GeneratedExpression = {
// if input has been used before, we can reuse the code that
// has already been generated
......@@ -1558,9 +1573,9 @@ class CodeGenerator(
// generate input access and unboxing if necessary
case None =>
val expr = if (nullableInput) {
generateNullableInputFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
generateNullableInputFieldAccess(inputType, inputTerm, index, fieldMapping)
} else {
generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
generateFieldAccess(inputType, inputTerm, index, fieldMapping)
}
reusableInputUnboxingExprs((inputTerm, index)) = expr
......@@ -1574,7 +1589,7 @@ class CodeGenerator(
inputType: TypeInformation[_ <: Any],
inputTerm: String,
index: Int,
pojoFieldMapping: Option[Array[Int]])
fieldMapping: Array[Int])
: GeneratedExpression = {
val resultTerm = newName("result")
val nullTerm = newName("isNull")
......@@ -1582,7 +1597,7 @@ class CodeGenerator(
val fieldType = inputType match {
case ct: CompositeType[_] =>
val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]]) {
pojoFieldMapping.get(index)
fieldMapping(index)
}
else {
index
......@@ -1593,7 +1608,7 @@ class CodeGenerator(
}
val resultTypeTerm = primitiveTypeTermForTypeInfo(fieldType)
val defaultValue = primitiveDefaultValue(fieldType)
val fieldAccessExpr = generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
val fieldAccessExpr = generateFieldAccess(inputType, inputTerm, index, fieldMapping)
val inputCheckCode =
s"""
......@@ -1617,12 +1632,12 @@ class CodeGenerator(
inputType: TypeInformation[_],
inputTerm: String,
index: Int,
pojoFieldMapping: Option[Array[Int]])
fieldMapping: Array[Int])
: GeneratedExpression = {
inputType match {
case ct: CompositeType[_] =>
val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]] && pojoFieldMapping.nonEmpty) {
pojoFieldMapping.get(index)
val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]]) {
fieldMapping(index)
}
else {
index
......
......@@ -29,7 +29,6 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils.GenericTypeInfo
import org.apache.flink.table.codegen.{CodeGenerator, GeneratedExpression}
import org.apache.flink.table.functions.{EventTimeExtractor, ProcTimeExtractor}
import org.apache.flink.table.functions.utils.{ScalarSqlFunction, TableSqlFunction}
import scala.collection.mutable
......@@ -496,15 +495,6 @@ object FunctionGenerator {
)
)
// generate a constant for time indicator functions.
// this is a temporary solution and will be removed when FLINK-5884 is implemented.
case ProcTimeExtractor | EventTimeExtractor =>
Some(new CallGenerator {
override def generate(codeGenerator: CodeGenerator, operands: Seq[GeneratedExpression]) = {
GeneratedExpression("0L", "false", "", SqlTimeTypeInfo.TIMESTAMP)
}
})
// built-in scalar function
case _ =>
sqlFunctions.get((sqlOperator, operandTypes))
......
......@@ -28,10 +28,49 @@ import org.apache.calcite.rex.{RexBuilder, RexNode}
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo}
import org.apache.flink.streaming.api.windowing.time.{Time => FlinkTime}
object ExpressionUtils {
private[flink] def isTimeIntervalLiteral(expr: Expression): Boolean = expr match {
case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) => true
case _ => false
}
private[flink] def isRowCountLiteral(expr: Expression): Boolean = expr match {
case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) => true
case _ => false
}
private[flink] def isTimeAttribute(expr: Expression): Boolean = expr match {
case r: ResolvedFieldReference if FlinkTypeFactory.isTimeIndicatorType(r.resultType) => true
case _ => false
}
private[flink] def isRowtimeAttribute(expr: Expression): Boolean = expr match {
case r: ResolvedFieldReference if FlinkTypeFactory.isRowtimeIndicatorType(r.resultType) => true
case _ => false
}
private[flink] def isProctimeAttribute(expr: Expression): Boolean = expr match {
case r: ResolvedFieldReference if FlinkTypeFactory.isProctimeIndicatorType(r.resultType) =>
true
case _ => false
}
private[flink] def toTime(expr: Expression): FlinkTime = expr match {
case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
FlinkTime.milliseconds(value)
case _ => throw new IllegalArgumentException()
}
private[flink] def toLong(expr: Expression): Long = expr match {
case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value
case _ => throw new IllegalArgumentException()
}
private[flink] def toMonthInterval(expr: Expression, multiplier: Int): Expression = expr match {
case Literal(value: Int, BasicTypeInfo.INT_TYPE_INFO) =>
Literal(value * multiplier, TimeIntervalTypeInfo.INTERVAL_MONTHS)
......
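These predicates and converters back the new window and over-window validation further below. As a rough illustration only, a size expression that has already passed validation could be mapped to either a Flink Time or a row count; the helper name below is an assumption and not part of this commit:

import org.apache.flink.streaming.api.windowing.time.{Time => FlinkTime}
import org.apache.flink.table.expressions.{Expression, ExpressionUtils}

// sketch: map a validated window size literal to a time span or a row count
def toWindowSize(size: Expression): Either[FlinkTime, Long] =
  if (ExpressionUtils.isTimeIntervalLiteral(size)) {
    Left(ExpressionUtils.toTime(size))   // literal of INTERVAL_MILLIS
  } else if (ExpressionUtils.isRowCountLiteral(size)) {
    Right(ExpressionUtils.toLong(size))  // literal of INTERVAL_ROWS
  } else {
    throw new IllegalArgumentException("Expected a time or row-count interval literal.")
  }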
......@@ -110,18 +110,11 @@ case class OverCall(
val aggExprs = agg.asInstanceOf[Aggregation].children.map(_.toRexNode(relBuilder)).asJava
// assemble order by key
val orderKey = orderBy match {
case _: RowTime =>
new RexFieldCollation(relBuilder.call(EventTimeExtractor), Set[SqlKind]().asJava)
case _: ProcTime =>
new RexFieldCollation(relBuilder.call(ProcTimeExtractor), Set[SqlKind]().asJava)
case _ =>
throw new ValidationException("Invalid OrderBy expression.")
}
val orderKey = new RexFieldCollation(orderBy.toRexNode, Set[SqlKind]().asJava)
val orderKeys = ImmutableList.of(orderKey)
// assemble partition by keys
val partitionKeys = partitionBy.map(_.toRexNode(relBuilder)).asJava
val partitionKeys = partitionBy.map(_.toRexNode).asJava
// assemble bounds
val isPhysical: Boolean = preceding.resultType.isInstanceOf[RowIntervalTypeInfo]
......@@ -249,6 +242,11 @@ case class OverCall(
return ValidationFailure("Preceding and following must be of same interval type.")
}
// check time field
if (!ExpressionUtils.isTimeAttribute(orderBy)) {
return ValidationFailure("Ordering must be defined on a time attribute.")
}
ValidationSuccess
}
}
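With this additional check, an over window must be ordered on a time attribute. A minimal Table API sketch that passes the check, assuming the usual org.apache.flink.table.api.scala._ imports and a streaming table with an appended 'proctime attribute (not taken from this commit):

// sketch: over aggregation ordered on a processing-time attribute
val result = table
  .window(Over partitionBy 'user orderBy 'proctime preceding UNBOUNDED_ROW as 'w)
  .select('user, 'amount.sum over 'w)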
......
......@@ -19,8 +19,9 @@ package org.apache.flink.table.expressions
import org.apache.calcite.rex.RexNode
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.table.api.{UnresolvedException, ValidationException}
import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo
import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
trait NamedExpression extends Expression {
......@@ -116,24 +117,6 @@ case class UnresolvedAlias(child: Expression) extends UnaryExpression with Named
override private[flink] lazy val valid = false
}
case class RowtimeAttribute() extends Attribute {
override private[flink] def withName(newName: String): Attribute = {
if (newName == "rowtime") {
this
} else {
throw new ValidationException("Cannot rename streaming rowtime attribute.")
}
}
override private[flink] def name: String = "rowtime"
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
throw new UnsupportedOperationException("A rowtime attribute can not be used solely.")
}
override private[flink] def resultType: TypeInformation[_] = BasicTypeInfo.LONG_TYPE_INFO
}
case class WindowReference(name: String) extends Attribute {
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
......@@ -150,3 +133,30 @@ case class WindowReference(name: String) extends Attribute {
}
}
}
abstract class TimeAttribute(val expression: Expression)
extends UnaryExpression
with NamedExpression {
override private[flink] def child: Expression = expression
override private[flink] def name: String = expression match {
case UnresolvedFieldReference(name) => name
case _ => throw new ValidationException("Unresolved field reference expected.")
}
override private[flink] def toAttribute: Attribute =
throw new UnsupportedOperationException("Time attribute can not be used solely.")
}
case class RowtimeAttribute(expr: Expression) extends TimeAttribute(expr) {
override private[flink] def resultType: TypeInformation[_] =
TimeIndicatorTypeInfo.ROWTIME_INDICATOR
}
case class ProctimeAttribute(expr: Expression) extends TimeAttribute(expr) {
override private[flink] def resultType: TypeInformation[_] =
TimeIndicatorTypeInfo.PROCTIME_INDICATOR
}
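RowtimeAttribute and ProctimeAttribute are created when a user marks a field as a time attribute while registering a stream. A hedged sketch of the user-facing declaration this resolves to, assuming the field-expression syntax wired up by this change and the Scala toTable implicits:

// sketch: append event-time and processing-time attributes when converting a DataStream
val table = stream.toTable(tEnv, 'user, 'amount, 'rowtime.rowtime, 'proctime.proctime)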
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.functions
import org.apache.calcite.sql._
import org.apache.calcite.sql.`type`._
import org.apache.calcite.sql.validate.SqlMonotonicity
/**
* Function that materializes a time attribute to the metadata timestamp. After materialization
* the result can be used in regular arithmetical calculations.
*/
object TimeMaterializationSqlFunction
extends SqlFunction(
"TIME_MATERIALIZATION",
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.TIMESTAMP),
InferTypes.RETURN_TYPE,
OperandTypes.family(SqlTypeFamily.TIMESTAMP),
SqlFunctionCategory.SYSTEM) {
override def getSyntax: SqlSyntax = SqlSyntax.FUNCTION
override def getMonotonicity(call: SqlOperatorBinding): SqlMonotonicity =
SqlMonotonicity.INCREASING
}
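For illustration, the planner can wrap a time indicator in this function wherever its concrete timestamp value is needed. A minimal sketch, assuming a RexBuilder, the input row type, and the index of the time field are in scope (none of these names are from this commit):

// sketch: materialize a time indicator field into a regular TIMESTAMP value
val timeFieldType = rowType.getFieldList.get(timeFieldIndex).getType
val timeIndicatorRef = rexBuilder.makeInputRef(timeFieldType, timeFieldIndex)
val materialized = rexBuilder.makeCall(TimeMaterializationSqlFunction, timeIndicatorRef)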
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.functions
import java.nio.charset.Charset
import java.util
import org.apache.calcite.rel.`type`._
import org.apache.calcite.sql._
import org.apache.calcite.sql.`type`.{OperandTypes, ReturnTypes, SqlTypeFamily, SqlTypeName}
import org.apache.calcite.sql.validate.SqlMonotonicity
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo
import org.apache.flink.table.api.TableException
import org.apache.flink.table.expressions.LeafExpression
object EventTimeExtractor extends SqlFunction("ROWTIME", SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(TimeModeTypes.ROWTIME), null, OperandTypes.NILADIC,
SqlFunctionCategory.SYSTEM) {
override def getSyntax: SqlSyntax = SqlSyntax.FUNCTION
override def getMonotonicity(call: SqlOperatorBinding): SqlMonotonicity =
SqlMonotonicity.INCREASING
}
object ProcTimeExtractor extends SqlFunction("PROCTIME", SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(TimeModeTypes.PROCTIME), null, OperandTypes.NILADIC,
SqlFunctionCategory.SYSTEM) {
override def getSyntax: SqlSyntax = SqlSyntax.FUNCTION
override def getMonotonicity(call: SqlOperatorBinding): SqlMonotonicity =
SqlMonotonicity.INCREASING
}
abstract class TimeIndicator extends LeafExpression {
/**
* Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]]
* for evaluating this expression.
* It is sometimes not available until the expression is valid.
*/
override private[flink] def resultType = SqlTimeTypeInfo.TIMESTAMP
/**
* Convert Expression to its counterpart in Calcite, i.e. RexNode
*/
override private[flink] def toRexNode(implicit relBuilder: RelBuilder) =
throw new TableException("indicator functions (e.g. proctime() and rowtime()" +
" are not executable. Please check your expressions.")
}
case class RowTime() extends TimeIndicator
case class ProcTime() extends TimeIndicator
object TimeModeTypes {
// indicator data type for row time (event time)
val ROWTIME = new RowTimeType
// indicator data type for processing time
val PROCTIME = new ProcTimeType
}
class RowTimeType extends TimeModeType {
override def toString(): String = "ROWTIME"
override def getFullTypeString: String = "ROWTIME_INDICATOR"
}
class ProcTimeType extends TimeModeType {
override def toString(): String = "PROCTIME"
override def getFullTypeString: String = "PROCTIME_INDICATOR"
}
abstract class TimeModeType extends RelDataType {
override def getComparability: RelDataTypeComparability = RelDataTypeComparability.NONE
override def isStruct: Boolean = false
override def getFieldList: util.List[RelDataTypeField] = null
override def getFieldNames: util.List[String] = null
override def getFieldCount: Int = 0
override def getStructKind: StructKind = StructKind.NONE
override def getField(
fieldName: String,
caseSensitive: Boolean,
elideRecord: Boolean): RelDataTypeField = null
override def isNullable: Boolean = false
override def getComponentType: RelDataType = null
override def getKeyType: RelDataType = null
override def getValueType: RelDataType = null
override def getCharset: Charset = null
override def getCollation: SqlCollation = null
override def getIntervalQualifier: SqlIntervalQualifier = null
override def getPrecision: Int = -1
override def getScale: Int = -1
override def getSqlTypeName: SqlTypeName = SqlTypeName.TIMESTAMP
override def getSqlIdentifier: SqlIdentifier = null
override def getFamily: RelDataTypeFamily = SqlTypeFamily.NUMERIC
override def getPrecedenceList: RelDataTypePrecedenceList = ???
override def isDynamicStruct: Boolean = false
}
......@@ -19,9 +19,8 @@
package org.apache.flink.table.plan
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.table.api.{OverWindow, StreamTableEnvironment, TableEnvironment}
import org.apache.flink.table.api.{OverWindow, TableEnvironment}
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.{ProcTime, RowTime}
import org.apache.flink.table.plan.logical.{LogicalNode, Project}
import scala.collection.mutable
......@@ -231,28 +230,12 @@ object ProjectionTranslator {
val overWindow = overWindows.find(_.alias.equals(unresolvedCall.alias))
if (overWindow.isDefined) {
if (tEnv.isInstanceOf[StreamTableEnvironment]) {
val timeIndicator = overWindow.get.orderBy match {
case u: UnresolvedFieldReference if u.name.toLowerCase == "rowtime" =>
RowTime()
case u: UnresolvedFieldReference if u.name.toLowerCase == "proctime" =>
ProcTime()
case e: Expression => e
}
OverCall(
unresolvedCall.agg,
overWindow.get.partitionBy,
timeIndicator,
overWindow.get.preceding,
overWindow.get.following)
} else {
OverCall(
unresolvedCall.agg,
overWindow.get.partitionBy,
overWindow.get.orderBy,
overWindow.get.preceding,
overWindow.get.following)
}
OverCall(
unresolvedCall.agg,
overWindow.get.partitionBy,
overWindow.get.orderBy,
overWindow.get.preceding,
overWindow.get.following)
} else {
unresolvedCall
}
......
......@@ -22,14 +22,24 @@ import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.expressions.{Expression, WindowReference}
import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
abstract class LogicalWindow(val alias: Expression) extends Resolvable[LogicalWindow] {
/**
* Logical super class for all types of windows (group-windows and row-windows).
*
* @param aliasAttribute window alias
* @param timeAttribute time field indicating event-time or processing-time
*/
abstract class LogicalWindow(
val aliasAttribute: Expression,
val timeAttribute: Expression)
extends Resolvable[LogicalWindow] {
def resolveExpressions(resolver: (Expression) => Expression): LogicalWindow = this
def validate(tableEnv: TableEnvironment): ValidationResult = alias match {
def validate(tableEnv: TableEnvironment): ValidationResult = aliasAttribute match {
case WindowReference(_) => ValidationSuccess
case _ => ValidationFailure("Window reference for window expected.")
}
override def toString: String = getClass.getSimpleName
}
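A hedged sketch of how a planner step might drive this small interface; the helper is an assumption and not part of this commit:

import org.apache.flink.table.api.{TableEnvironment, ValidationException}
import org.apache.flink.table.expressions.Expression
import org.apache.flink.table.validate.ValidationFailure

// sketch: resolve a window's expressions and validate it before translation
def prepareWindow(
    window: LogicalWindow,
    resolver: Expression => Expression,
    tableEnv: TableEnvironment): LogicalWindow = {
  val resolved = window.resolveExpressions(resolver)
  resolved.validate(tableEnv) match {
    case ValidationFailure(msg) => throw new ValidationException(msg)
    case _ => resolved
  }
}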
......@@ -18,259 +18,165 @@
package org.apache.flink.table.plan.logical
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.table.api.{BatchTableEnvironment, StreamTableEnvironment, TableEnvironment}
import org.apache.flink.table.expressions.ExpressionUtils.{isRowCountLiteral, isRowtimeAttribute, isTimeAttribute, isTimeIntervalLiteral}
import org.apache.flink.table.expressions._
import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo, TypeCoercion}
import org.apache.flink.table.typeutils.TypeCheckUtils.isTimePoint
import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
abstract class EventTimeGroupWindow(
alias: Expression,
time: Expression)
extends LogicalWindow(alias) {
override def validate(tableEnv: TableEnvironment): ValidationResult = {
val valid = super.validate(tableEnv)
if (valid.isFailure) {
return valid
}
tableEnv match {
case _: StreamTableEnvironment =>
time match {
case RowtimeAttribute() =>
ValidationSuccess
case _ =>
ValidationFailure("Event-time window expects a 'rowtime' time field.")
}
case _: BatchTableEnvironment =>
if (!TypeCoercion.canCast(time.resultType, BasicTypeInfo.LONG_TYPE_INFO)) {
ValidationFailure(s"Event-time window expects a time field that can be safely cast " +
s"to Long, but is ${time.resultType}")
} else {
ValidationSuccess
}
}
}
}
abstract class ProcessingTimeGroupWindow(alias: Expression) extends LogicalWindow(alias) {
override def validate(tableEnv: TableEnvironment): ValidationResult = {
val valid = super.validate(tableEnv)
if (valid.isFailure) {
return valid
}
tableEnv match {
case b: BatchTableEnvironment => ValidationFailure(
"Window on batch must declare a time attribute over which the query is evaluated.")
case _ =>
ValidationSuccess
}
}
}
// ------------------------------------------------------------------------------------------------
// Tumbling group windows
// ------------------------------------------------------------------------------------------------
object TumblingGroupWindow {
def validate(tableEnv: TableEnvironment, size: Expression): ValidationResult = size match {
case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
ValidationSuccess
case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
ValidationSuccess
case _ =>
ValidationFailure("Tumbling window expects size literal of type Interval of Milliseconds " +
"or Interval of Rows.")
}
}
case class ProcessingTimeTumblingGroupWindow(
override val alias: Expression,
size: Expression)
extends ProcessingTimeGroupWindow(alias) {
override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
ProcessingTimeTumblingGroupWindow(
resolve(alias),
resolve(size))
override def validate(tableEnv: TableEnvironment): ValidationResult =
super.validate(tableEnv).orElse(TumblingGroupWindow.validate(tableEnv, size))
override def toString: String = s"ProcessingTimeTumblingGroupWindow($alias, $size)"
}
case class EventTimeTumblingGroupWindow(
override val alias: Expression,
case class TumblingGroupWindow(
alias: Expression,
timeField: Expression,
size: Expression)
extends EventTimeGroupWindow(
extends LogicalWindow(
alias,
timeField) {
override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
EventTimeTumblingGroupWindow(
TumblingGroupWindow(
resolve(alias),
resolve(timeField),
resolve(size))
override def validate(tableEnv: TableEnvironment): ValidationResult =
super.validate(tableEnv)
.orElse(TumblingGroupWindow.validate(tableEnv, size))
.orElse(size match {
case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS)
if tableEnv.isInstanceOf[StreamTableEnvironment] =>
super.validate(tableEnv).orElse(
tableEnv match {
// check size
case _ if !isTimeIntervalLiteral(size) && !isRowCountLiteral(size) =>
ValidationFailure(
"Tumbling window expects size literal of type Interval of Milliseconds " +
"or Interval of Rows.")
// check time attribute
case _: StreamTableEnvironment if !isTimeAttribute(timeField) =>
ValidationFailure(
"Tumbling window expects a time attribute for grouping in a stream environment.")
case _: BatchTableEnvironment if isTimePoint(size.resultType) =>
ValidationFailure(
"Tumbling window expects a time attribute for grouping in a stream environment.")
// check row intervals on event-time
case _: StreamTableEnvironment
if isRowCountLiteral(size) && isRowtimeAttribute(timeField) =>
ValidationFailure(
"Event-time grouping windows on row intervals in a stream environment " +
"are currently not supported.")
case _ =>
ValidationSuccess
})
}
)
override def toString: String = s"EventTimeTumblingGroupWindow($alias, $timeField, $size)"
override def toString: String = s"TumblingGroupWindow($alias, $timeField, $size)"
}
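A Table API sketch that satisfies the streaming checks above, assuming the usual org.apache.flink.table.api.scala._ imports and a table with a 'rowtime event-time attribute (not taken from this commit):

// sketch: event-time tumbling window grouped by user
val result = table
  .window(Tumble over 10.minutes on 'rowtime as 'w)
  .groupBy('w, 'user)
  .select('user, 'amount.sum, 'w.end)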
// ------------------------------------------------------------------------------------------------
// Sliding group windows
// ------------------------------------------------------------------------------------------------
object SlidingGroupWindow {
def validate(
tableEnv: TableEnvironment,
size: Expression,
slide: Expression)
: ValidationResult = {
val checkedSize = size match {
case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
ValidationSuccess
case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
ValidationSuccess
case _ =>
ValidationFailure("Sliding window expects size literal of type Interval of " +
"Milliseconds or Interval of Rows.")
}
val checkedSlide = slide match {
case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
ValidationSuccess
case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
ValidationSuccess
case _ =>
ValidationFailure("Sliding window expects slide literal of type Interval of " +
"Milliseconds or Interval of Rows.")
}
checkedSize
.orElse(checkedSlide)
.orElse {
if (size.resultType != slide.resultType) {
ValidationFailure("Sliding window expects same type of size and slide.")
} else {
ValidationSuccess
}
}
}
}
case class ProcessingTimeSlidingGroupWindow(
override val alias: Expression,
case class SlidingGroupWindow(
alias: Expression,
timeField: Expression,
size: Expression,
slide: Expression)
extends ProcessingTimeGroupWindow(alias) {
extends LogicalWindow(
alias,
timeField) {
override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
ProcessingTimeSlidingGroupWindow(
SlidingGroupWindow(
resolve(alias),
resolve(timeField),
resolve(size),
resolve(slide))
override def validate(tableEnv: TableEnvironment): ValidationResult =
super.validate(tableEnv).orElse(SlidingGroupWindow.validate(tableEnv, size, slide))
super.validate(tableEnv).orElse(
tableEnv match {
override def toString: String = s"ProcessingTimeSlidingGroupWindow($alias, $size, $slide)"
}
// check size
case _ if !isTimeIntervalLiteral(size) && !isRowCountLiteral(size) =>
ValidationFailure(
"Sliding window expects size literal of type Interval of Milliseconds " +
"or Interval of Rows.")
case class EventTimeSlidingGroupWindow(
override val alias: Expression,
timeField: Expression,
size: Expression,
slide: Expression)
extends EventTimeGroupWindow(alias, timeField) {
// check slide
case _ if !isTimeIntervalLiteral(slide) && !isRowCountLiteral(slide) =>
ValidationFailure(
"Sliding window expects slide literal of type Interval of Milliseconds " +
"or Interval of Rows.")
override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
EventTimeSlidingGroupWindow(
resolve(alias),
resolve(timeField),
resolve(size),
resolve(slide))
// check same type of intervals
case _ if isTimeIntervalLiteral(size) != isTimeIntervalLiteral(slide) =>
ValidationFailure("Sliding window expects same type of size and slide.")
override def validate(tableEnv: TableEnvironment): ValidationResult =
super.validate(tableEnv)
.orElse(SlidingGroupWindow.validate(tableEnv, size, slide))
.orElse(size match {
case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS)
if tableEnv.isInstanceOf[StreamTableEnvironment] =>
// check time attribute
case _: StreamTableEnvironment if !isTimeAttribute(timeField) =>
ValidationFailure(
"Sliding window expects a time attribute for grouping in a stream environment.")
case _: BatchTableEnvironment if isTimePoint(size.resultType) =>
ValidationFailure(
"Sliding window expects a time attribute for grouping in a stream environment.")
// check row intervals on event-time
case _: StreamTableEnvironment
if isRowCountLiteral(size) && isRowtimeAttribute(timeField) =>
ValidationFailure(
"Event-time grouping windows on row intervals in a stream environment " +
"are currently not supported.")
case _ =>
ValidationSuccess
})
}
)
override def toString: String = s"EventTimeSlidingGroupWindow($alias, $timeField, $size, $slide)"
override def toString: String = s"SlidingGroupWindow($alias, $timeField, $size, $slide)"
}
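Analogously, a sliding window over row intervals is only accepted on processing time in a stream environment; a hedged usage sketch under the same assumptions as above:

// sketch: processing-time sliding count window
val result = table
  .window(Slide over 100.rows every 10.rows on 'proctime as 'w)
  .groupBy('w, 'user)
  .select('user, 'amount.sum)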
// ------------------------------------------------------------------------------------------------
// Session group windows
// ------------------------------------------------------------------------------------------------
object SessionGroupWindow {
def validate(tableEnv: TableEnvironment, gap: Expression): ValidationResult = gap match {
case Literal(timeInterval: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
ValidationSuccess
case _ =>
ValidationFailure(
"Session window expects gap literal of type Interval of Milliseconds.")
}
}
case class ProcessingTimeSessionGroupWindow(
override val alias: Expression,
gap: Expression)
extends ProcessingTimeGroupWindow(alias) {
override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
ProcessingTimeSessionGroupWindow(
resolve(alias),
resolve(gap))
override def validate(tableEnv: TableEnvironment): ValidationResult =
super.validate(tableEnv).orElse(SessionGroupWindow.validate(tableEnv, gap))
override def toString: String = s"ProcessingTimeSessionGroupWindow($alias, $gap)"
}
case class EventTimeSessionGroupWindow(
override val alias: Expression,
case class SessionGroupWindow(
alias: Expression,
timeField: Expression,
gap: Expression)
extends EventTimeGroupWindow(
extends LogicalWindow(
alias,
timeField) {
override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
EventTimeSessionGroupWindow(
SessionGroupWindow(
resolve(alias),
resolve(timeField),
resolve(gap))
override def validate(tableEnv: TableEnvironment): ValidationResult =
super.validate(tableEnv).orElse(SessionGroupWindow.validate(tableEnv, gap))
super.validate(tableEnv).orElse(
tableEnv match {
// check size
case _ if !isTimeIntervalLiteral(gap) =>
ValidationFailure(
"Session window expects size literal of type Interval of Milliseconds.")
// check time attribute
case _: StreamTableEnvironment if !isTimeAttribute(timeField) =>
ValidationFailure(
"Session window expects a time attribute for grouping in a stream environment.")
case _: BatchTableEnvironment if isTimePoint(gap.resultType) =>
ValidationFailure(
"Session window expects a time attribute for grouping in a stream environment.")
case _ =>
ValidationSuccess
}
)
override def toString: String = s"EventTimeSessionGroupWindow($alias, $timeField, $gap)"
override def toString: String = s"SessionGroupWindow($alias, $timeField, $gap)"
}
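A session window always needs a time-interval gap and, on streams, a time attribute; a hedged usage sketch under the same assumptions as above:

// sketch: event-time session window with a 30 minute gap
val result = table
  .window(Session withGap 30.minutes on 'rowtime as 'w)
  .groupBy('w, 'user)
  .select('user, 'amount.count, 'w.start, 'w.end)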
......@@ -70,8 +70,6 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extend
def checkName(name: String): Unit = {
if (names.contains(name)) {
failValidation(s"Duplicate field name $name.")
} else if (tableEnv.isInstanceOf[StreamTableEnvironment] && name == "rowtime") {
failValidation("'rowtime' cannot be used as field name in a streaming environment.")
} else {
names.add(name)
}
......@@ -112,10 +110,6 @@ case class AliasNode(aliasList: Seq[Expression], child: LogicalNode) extends Una
failValidation("Alias only accept name expressions as arguments")
} else if (!aliasList.forall(_.asInstanceOf[UnresolvedFieldReference].name != "*")) {
failValidation("Alias can not accept '*' as name")
} else if (tableEnv.isInstanceOf[StreamTableEnvironment] && !aliasList.forall {
case UnresolvedFieldReference(name) => name != "rowtime"
}) {
failValidation("'rowtime' cannot be used as field name in a streaming environment.")
} else {
val names = aliasList.map(_.asInstanceOf[UnresolvedFieldReference].name)
val input = child.output
......@@ -561,26 +555,20 @@ case class WindowAggregate(
override def resolveReference(
tableEnv: TableEnvironment,
name: String)
: Option[NamedExpression] = tableEnv match {
// resolve reference to rowtime attribute in a streaming environment
case _: StreamTableEnvironment if name == "rowtime" =>
Some(RowtimeAttribute())
case _ =>
window.alias match {
// resolve reference to this window's alias
case UnresolvedFieldReference(alias) if name == alias =>
// check if reference can already be resolved by input fields
val found = super.resolveReference(tableEnv, name)
if (found.isDefined) {
failValidation(s"Reference $name is ambiguous.")
} else {
Some(WindowReference(name))
}
case _ =>
// resolve references as usual
super.resolveReference(tableEnv, name)
}
}
: Option[NamedExpression] = window.aliasAttribute match {
// resolve reference to this window's name
case UnresolvedFieldReference(alias) if name == alias =>
// check if reference can already be resolved by input fields
val found = super.resolveReference(tableEnv, name)
if (found.isDefined) {
failValidation(s"Reference $name is ambiguous.")
} else {
Some(WindowReference(name))
}
case _ =>
// resolve references as usual
super.resolveReference(tableEnv, name)
}
override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
val flinkRelBuilder = relBuilder.asInstanceOf[FlinkRelBuilder]
......
......@@ -19,13 +19,12 @@
package org.apache.flink.table.plan.nodes
import org.apache.calcite.plan.{RelOptCost, RelOptPlanner}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex._
import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.{CodeGenerator, GeneratedFunction}
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.FlatMapRunner
import org.apache.flink.types.Row
......@@ -35,21 +34,30 @@ import scala.collection.JavaConverters._
trait CommonCalc {
private[flink] def functionBody(
generator: CodeGenerator,
inputType: TypeInformation[Row],
rowType: RelDataType,
calcProgram: RexProgram,
config: TableConfig)
generator: CodeGenerator,
inputSchema: RowSchema,
returnSchema: RowSchema,
calcProgram: RexProgram,
config: TableConfig)
: String = {
val returnType = FlinkTypeFactory.toInternalRowTypeInfo(rowType)
val expandedExpressions = calcProgram
.getProjectList
.map(expr => calcProgram.expandLocalRef(expr))
// time indicator fields must not be part of the code generation
.filter(expr => !FlinkTypeFactory.isTimeIndicatorType(expr.getType))
// update indices
.map(expr => inputSchema.mapRexNode(expr))
val condition = if (calcProgram.getCondition != null) {
inputSchema.mapRexNode(calcProgram.expandLocalRef(calcProgram.getCondition))
} else {
null
}
val condition = calcProgram.getCondition
val expandedExpressions = calcProgram.getProjectList.map(
expr => calcProgram.expandLocalRef(expr))
val projection = generator.generateResultExpression(
returnType,
rowType.getFieldNames,
returnSchema.physicalTypeInfo,
returnSchema.physicalFieldNames,
expandedExpressions)
// only projection
......@@ -60,8 +68,7 @@ trait CommonCalc {
|""".stripMargin
}
else {
val filterCondition = generator.generateExpression(
calcProgram.expandLocalRef(calcProgram.getCondition))
val filterCondition = generator.generateExpression(condition)
// only filter
if (projection == null) {
s"""
......
......@@ -23,11 +23,11 @@ import org.apache.calcite.sql.SemiJoinType
import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.{TableConfig, TableException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenUtils.primitiveDefaultValue
import org.apache.flink.table.codegen.GeneratedExpression.{ALWAYS_NULL, NO_CODE}
import org.apache.flink.table.codegen.{CodeGenerator, GeneratedCollector, GeneratedExpression, GeneratedFunction}
import org.apache.flink.table.functions.utils.TableSqlFunction
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.{CorrelateFlatMapRunner, TableFunctionCollector}
import org.apache.flink.types.Row
......@@ -44,9 +44,9 @@ trait CommonCorrelate {
*/
private[flink] def correlateMapFunction(
config: TableConfig,
inputTypeInfo: TypeInformation[Row],
inputSchema: RowSchema,
udtfTypeInfo: TypeInformation[Any],
rowType: RelDataType,
returnSchema: RowSchema,
joinType: SemiJoinType,
rexCall: RexCall,
condition: Option[RexNode],
......@@ -54,26 +54,24 @@ trait CommonCorrelate {
ruleDescription: String)
: CorrelateFlatMapRunner[Row, Row] = {
val returnType = FlinkTypeFactory.toInternalRowTypeInfo(rowType)
val flatMap = generateFunction(
config,
inputTypeInfo,
inputSchema.physicalTypeInfo,
udtfTypeInfo,
returnType,
rowType,
returnSchema.physicalTypeInfo,
returnSchema.logicalFieldNames,
joinType,
rexCall,
inputSchema.mapRexNode(rexCall).asInstanceOf[RexCall],
pojoFieldMapping,
ruleDescription)
val collector = generateCollector(
config,
inputTypeInfo,
inputSchema.physicalTypeInfo,
udtfTypeInfo,
returnType,
rowType,
condition,
returnSchema.physicalTypeInfo,
returnSchema.logicalFieldNames,
condition.map(inputSchema.mapRexNode),
pojoFieldMapping)
new CorrelateFlatMapRunner[Row, Row](
......@@ -93,7 +91,7 @@ trait CommonCorrelate {
inputTypeInfo: TypeInformation[Row],
udtfTypeInfo: TypeInformation[Any],
returnType: TypeInformation[Row],
rowType: RelDataType,
resultFieldNames: Seq[String],
joinType: SemiJoinType,
rexCall: RexCall,
pojoFieldMapping: Option[Array[Int]],
......@@ -134,7 +132,7 @@ trait CommonCorrelate {
x.resultType)
}
val outerResultExpr = functionGenerator.generateResultExpression(
input1AccessExprs ++ input2NullExprs, returnType, rowType.getFieldNames.asScala)
input1AccessExprs ++ input2NullExprs, returnType, resultFieldNames)
body +=
s"""
|boolean hasOutput = $collectorTerm.isCollected();
......@@ -162,7 +160,7 @@ trait CommonCorrelate {
inputTypeInfo: TypeInformation[Row],
udtfTypeInfo: TypeInformation[Any],
returnType: TypeInformation[Row],
rowType: RelDataType,
resultFieldNames: Seq[String],
condition: Option[RexNode],
pojoFieldMapping: Option[Array[Int]])
: GeneratedCollector = {
......@@ -180,7 +178,7 @@ trait CommonCorrelate {
val crossResultExpr = generator.generateResultExpression(
input1AccessExprs ++ input2AccessExprs,
returnType,
rowType.getFieldNames.asScala)
resultFieldNames)
val collectorCode = if (condition.isEmpty) {
s"""
......
......@@ -18,11 +18,10 @@
package org.apache.flink.table.plan.nodes
import org.apache.flink.api.common.functions.MapFunction
import org.apache.flink.api.common.functions.Function
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.runtime.MapRunner
import org.apache.flink.table.codegen.{CodeGenerator, GeneratedFunction}
import org.apache.flink.types.Row
/**
......@@ -42,21 +41,22 @@ trait CommonScan {
externalTypeInfo != internalTypeInfo
}
private[flink] def getConversionMapper(
private[flink] def generatedConversionFunction[F <: Function](
config: TableConfig,
functionClass: Class[F],
inputType: TypeInformation[Any],
expectedType: TypeInformation[Row],
conversionOperatorName: String,
fieldNames: Seq[String],
inputPojoFieldMapping: Option[Array[Int]] = None)
: MapFunction[Any, Row] = {
inputFieldMapping: Option[Array[Int]] = None)
: GeneratedFunction[F, Row] = {
val generator = new CodeGenerator(
config,
false,
inputType,
None,
inputPojoFieldMapping)
inputFieldMapping)
val conversion = generator.generateConverterResultExpression(expectedType, fieldNames)
val body =
......@@ -65,17 +65,11 @@ trait CommonScan {
|return ${conversion.resultTerm};
|""".stripMargin
val genFunction = generator.generateFunction(
generator.generateFunction(
conversionOperatorName,
classOf[MapFunction[Any, Row]],
functionClass,
body,
expectedType)
new MapRunner[Any, Row](
genFunction.name,
genFunction.code,
genFunction.returnType)
}
}
......@@ -18,14 +18,13 @@
package org.apache.flink.table.plan.nodes
import org.apache.calcite.rel.{RelFieldCollation, RelNode}
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFieldImpl}
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{AggregateCall, Window}
import org.apache.calcite.rel.core.Window.Group
import org.apache.calcite.rel.core.Window
import org.apache.calcite.rex.{RexInputRef}
import org.apache.calcite.rel.{RelFieldCollation, RelNode}
import org.apache.calcite.rex.RexInputRef
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.aggregate.AggregateUtil._
import org.apache.flink.table.functions.{ProcTimeType, RowTimeType}
import scala.collection.JavaConverters._
......@@ -43,7 +42,7 @@ trait OverAggregate {
val inFields = inputType.getFieldList.asScala
val orderingString = orderFields.asScala.map {
x => inFields(x.getFieldIndex).getValue
x => inFields(x.getFieldIndex).getName
}.mkString(", ")
orderingString
......@@ -66,24 +65,8 @@ trait OverAggregate {
rowType: RelDataType,
namedAggregates: Seq[CalcitePair[AggregateCall, String]]): String = {
val inFields = inputType.getFieldList.asScala.map {
x =>
x.asInstanceOf[RelDataTypeFieldImpl].getType
match {
case proceTime: ProcTimeType => "PROCTIME"
case rowTime: RowTimeType => "ROWTIME"
case _ => x.asInstanceOf[RelDataTypeFieldImpl].getName
}
}
val outFields = rowType.getFieldList.asScala.map {
x =>
x.asInstanceOf[RelDataTypeFieldImpl].getType
match {
case proceTime: ProcTimeType => "PROCTIME"
case rowTime: RowTimeType => "ROWTIME"
case _ => x.asInstanceOf[RelDataTypeFieldImpl].getName
}
}
val inFields = inputType.getFieldNames.asScala
val outFields = rowType.getFieldNames.asScala
val aggStrings = namedAggregates.map(_.getKey).map(
a => s"${a.getAggregation}(${
......@@ -109,7 +92,7 @@ trait OverAggregate {
input: RelNode): Long = {
val ref: RexInputRef = overWindow.lowerBound.getOffset.asInstanceOf[RexInputRef]
val lowerBoundIndex = input.getRowType.getFieldCount - ref.getIndex;
val lowerBoundIndex = input.getRowType.getFieldCount - ref.getIndex
val lowerBound = logicWindow.constants.get(lowerBoundIndex).getValue2
lowerBound match {
case x: java.math.BigDecimal => x.asInstanceOf[java.math.BigDecimal].longValue()
......
......@@ -37,9 +37,11 @@ abstract class PhysicalTableSourceScan(
override def deriveRowType(): RelDataType = {
val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
flinkTypeFactory.buildRowDataType(
flinkTypeFactory.buildLogicalRowType(
TableEnvironment.getFieldNames(tableSource),
TableEnvironment.getFieldTypes(tableSource.getReturnType))
TableEnvironment.getFieldTypes(tableSource.getReturnType),
None,
None)
}
override def explainTerms(pw: RelWriter): RelWriter = {
......
......@@ -18,11 +18,13 @@
package org.apache.flink.table.plan.nodes.dataset
import org.apache.flink.api.common.functions.MapFunction
import org.apache.flink.api.java.DataSet
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.nodes.CommonScan
import org.apache.flink.table.plan.schema.FlinkTable
import org.apache.flink.table.runtime.MapRunner
import org.apache.flink.types.Row
import scala.collection.JavaConversions._
......@@ -43,17 +45,23 @@ trait BatchScan extends CommonScan with DataSetRel {
// conversion
if (needsConversion(inputType, internalType)) {
val mapFunc = getConversionMapper(
val function = generatedConversionFunction(
config,
classOf[MapFunction[Any, Row]],
inputType,
internalType,
"DataSetSourceConversion",
getRowType.getFieldNames,
Some(flinkTable.fieldIndexes))
val runner = new MapRunner[Any, Row](
function.name,
function.code,
function.returnType)
val opName = s"from: (${getRowType.getFieldNames.asScala.toList.mkString(", ")})"
input.map(mapFunc).name(opName)
input.map(runner).name(opName)
}
// no conversion necessary, forward
else {
......
......@@ -22,7 +22,8 @@ import org.apache.calcite.plan._
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.flink.api.java.DataSet
import org.apache.flink.table.api.BatchTableEnvironment
import org.apache.flink.table.api.{BatchTableEnvironment, TableEnvironment}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.nodes.PhysicalTableSourceScan
import org.apache.flink.table.plan.schema.TableSourceTable
import org.apache.flink.table.sources.{BatchTableSource, TableSource}
......@@ -37,7 +38,16 @@ class BatchTableSourceScan(
extends PhysicalTableSourceScan(cluster, traitSet, table, tableSource)
with BatchScan {
override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
override def deriveRowType() = {
val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
flinkTypeFactory.buildLogicalRowType(
TableEnvironment.getFieldNames(tableSource),
TableEnvironment.getFieldTypes(tableSource.getReturnType),
None,
None)
}
override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val rowCnt = metadata.getRowCount(this)
planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * estimateRowSize(getRowType))
}
......
......@@ -91,6 +91,9 @@ class DataSetAggregate(
override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = {
val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
val input = inputNode.asInstanceOf[DataSetRel]
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
val generator = new CodeGenerator(
tableEnv.getConfig,
......@@ -104,15 +107,14 @@ class DataSetAggregate(
) = AggregateUtil.createDataSetAggregateFunctions(
generator,
namedAggregates,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
inGroupingSet)
val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
if (grouping.length > 0) {
// grouped aggregation
val aggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " +
......
......@@ -26,10 +26,13 @@ import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.calcite.rex._
import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.BatchTableEnvironment
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.plan.nodes.CommonCalc
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.FlatMapRunner
import org.apache.flink.types.Row
/**
......@@ -83,14 +86,14 @@ class DataSetCalc(
val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
val generator = new CodeGenerator(config, false, inputDS.getType)
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
val body = functionBody(
generator,
inputDS.getType,
getRowType,
new RowSchema(getInput.getRowType),
new RowSchema(getRowType),
calcProgram,
config)
......@@ -98,9 +101,13 @@ class DataSetCalc(
ruleDescription,
classOf[FlatMapFunction[Row, Row]],
body,
returnType)
rowTypeInfo)
val runner = new FlatMapRunner[Row, Row](
genFunction.name,
genFunction.code,
genFunction.returnType)
val mapFunc = calcMapFunction(genFunction)
inputDS.flatMap(mapFunc).name(calcOpName(calcProgram, getExpressionString))
inputDS.flatMap(runner).name(calcOpName(calcProgram, getExpressionString))
}
}
......@@ -25,10 +25,13 @@ import org.apache.calcite.rex.{RexCall, RexNode}
import org.apache.calcite.sql.SemiJoinType
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.BatchTableEnvironment
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.functions.utils.TableSqlFunction
import org.apache.flink.table.plan.nodes.CommonCorrelate
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.types.Row
/**
......@@ -98,11 +101,13 @@ class DataSetCorrelate(
val pojoFieldMapping = sqlFunction.getPojoFieldMapping
val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
val mapFunc = correlateMapFunction(
config,
inputDS.getType,
new RowSchema(getInput.getRowType),
udtfTypeInfo,
getRowType,
new RowSchema(getRowType),
joinType,
rexCall,
condition,
......
......@@ -24,15 +24,16 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.api.java.typeutils.{ResultTypeQueryable, RowTypeInfo}
import org.apache.flink.table.api.BatchTableEnvironment
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.expressions.ExpressionUtils._
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.plan.nodes.CommonAggregate
import org.apache.flink.table.runtime.aggregate.AggregateUtil.{CalcitePair, _}
import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval
import org.apache.flink.table.typeutils.TypeCheckUtils.{isLong, isTimePoint}
import org.apache.flink.types.Row
/**
......@@ -106,8 +107,6 @@ class DataSetWindowAggregate(
override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = {
val config = tableEnv.getConfig
val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
val generator = new CodeGenerator(
......@@ -119,30 +118,31 @@ class DataSetWindowAggregate(
val caseSensitive = tableEnv.getFrameworkConfig.getParserConfig.caseSensitive()
window match {
case EventTimeTumblingGroupWindow(_, _, size) =>
case TumblingGroupWindow(_, timeField, size)
if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
createEventTimeTumblingWindowDataSet(
generator,
inputDS,
isTimeInterval(size.resultType),
isTimeIntervalLiteral(size),
caseSensitive)
case EventTimeSessionGroupWindow(_, _, gap) =>
case SessionGroupWindow(_, timeField, gap)
if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
createEventTimeSessionWindowDataSet(generator, inputDS, caseSensitive)
case EventTimeSlidingGroupWindow(_, _, size, slide) =>
case SlidingGroupWindow(_, timeField, size, slide)
if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
createEventTimeSlidingWindowDataSet(
generator,
inputDS,
isTimeInterval(size.resultType),
isTimeIntervalLiteral(size),
asLong(size),
asLong(slide),
caseSensitive)
case _: ProcessingTimeGroupWindow =>
case _ =>
throw new UnsupportedOperationException(
"Processing-time tumbling windows are not supported in a batch environment, " +
"windows in a batch environment must declare a time attribute over which " +
"the query is evaluated.")
s"Window $window is not supported in a batch environment.")
}
}
......@@ -152,18 +152,22 @@ class DataSetWindowAggregate(
isTimeWindow: Boolean,
isParserCaseSensitive: Boolean): DataSet[Row] = {
val input = inputNode.asInstanceOf[DataSetRel]
val mapFunction = createDataSetWindowPrepareMapFunction(
generator,
window,
namedAggregates,
grouping,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
isParserCaseSensitive)
val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
generator,
window,
namedAggregates,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
getRowType,
grouping,
namedProperties)
......@@ -210,6 +214,8 @@ class DataSetWindowAggregate(
inputDS: DataSet[Row],
isParserCaseSensitive: Boolean): DataSet[Row] = {
val input = inputNode.asInstanceOf[DataSetRel]
val groupingKeys = grouping.indices.toArray
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
......@@ -219,7 +225,8 @@ class DataSetWindowAggregate(
window,
namedAggregates,
grouping,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
isParserCaseSensitive)
val mappedInput = inputDS.map(mapFunction).name(prepareOperatorName)
......@@ -245,7 +252,8 @@ class DataSetWindowAggregate(
generator,
window,
namedAggregates,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
grouping)
// create groupReduceFunction for calculating the aggregations
......@@ -253,7 +261,8 @@ class DataSetWindowAggregate(
generator,
window,
namedAggregates,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
namedProperties,
......@@ -275,7 +284,8 @@ class DataSetWindowAggregate(
generator,
window,
namedAggregates,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
grouping)
// create groupReduceFunction for calculating the aggregations
......@@ -283,7 +293,8 @@ class DataSetWindowAggregate(
generator,
window,
namedAggregates,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
namedProperties,
......@@ -308,7 +319,8 @@ class DataSetWindowAggregate(
generator,
window,
namedAggregates,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
namedProperties)
......@@ -324,7 +336,8 @@ class DataSetWindowAggregate(
generator,
window,
namedAggregates,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
namedProperties)
......@@ -347,6 +360,8 @@ class DataSetWindowAggregate(
isParserCaseSensitive: Boolean)
: DataSet[Row] = {
val input = inputNode.asInstanceOf[DataSetRel]
// create MapFunction for initializing the aggregations
// it aligns the rowtime for pre-tumbling in case of a time-window for partial aggregates
val mapFunction = createDataSetWindowPrepareMapFunction(
......@@ -354,7 +369,8 @@ class DataSetWindowAggregate(
window,
namedAggregates,
grouping,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
isParserCaseSensitive)
val mappedDataSet = inputDS
......@@ -390,7 +406,8 @@ class DataSetWindowAggregate(
window,
namedAggregates,
grouping,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
isParserCaseSensitive)
mappedDataSet.asInstanceOf[DataSet[Row]]
......@@ -426,7 +443,8 @@ class DataSetWindowAggregate(
generator,
window,
namedAggregates,
inputType,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
namedProperties,
......
......@@ -25,20 +25,18 @@ import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream}
import org.apache.flink.streaming.api.windowing.assigners._
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
import org.apache.flink.table.api.StreamTableEnvironment
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.expressions._
import org.apache.flink.table.expressions.ExpressionUtils._
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.plan.nodes.CommonAggregate
import org.apache.flink.table.plan.nodes.datastream.DataStreamAggregate._
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.aggregate.AggregateUtil._
import org.apache.flink.table.runtime.aggregate._
import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval
import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo}
import org.apache.flink.types.Row
class DataStreamAggregate(
......@@ -48,12 +46,12 @@ class DataStreamAggregate(
traitSet: RelTraitSet,
inputNode: RelNode,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
rowRelDataType: RelDataType,
inputType: RelDataType,
schema: RowSchema,
inputSchema: RowSchema,
grouping: Array[Int])
extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel {
override def deriveRowType(): RelDataType = rowRelDataType
override def deriveRowType(): RelDataType = schema.logicalType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamAggregate(
......@@ -63,22 +61,22 @@ class DataStreamAggregate(
traitSet,
inputs.get(0),
namedAggregates,
getRowType,
inputType,
schema,
inputSchema,
grouping)
}
override def toString: String = {
s"Aggregate(${
if (!grouping.isEmpty) {
s"groupBy: (${groupingToString(inputType, grouping)}), "
s"groupBy: (${groupingToString(inputSchema.logicalType, grouping)}), "
} else {
""
}
}window: ($window), " +
s"select: (${
aggregationToString(
inputType,
inputSchema.logicalType,
grouping,
getRowType,
namedAggregates,
......@@ -88,13 +86,13 @@ class DataStreamAggregate(
override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw)
.itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty)
.itemIf("groupBy", groupingToString(inputSchema.logicalType, grouping), !grouping.isEmpty)
.item("window", window)
.item(
"select", aggregationToString(
inputType,
inputSchema.logicalType,
grouping,
getRowType,
schema.logicalType,
namedAggregates,
namedProperties))
}
......@@ -102,17 +100,20 @@ class DataStreamAggregate(
override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = {
val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
val physicalNamedAggregates = namedAggregates.map { namedAggregate =>
new CalcitePair[AggregateCall, String](
inputSchema.mapAggregateCall(namedAggregate.left),
namedAggregate.right)
}
val aggString = aggregationToString(
inputType,
inputSchema.logicalType,
grouping,
getRowType,
schema.logicalType,
namedAggregates,
namedProperties)
val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " +
val keyedAggOpName = s"groupBy: (${groupingToString(schema.logicalType, grouping)}), " +
s"window: ($window), " +
s"select: ($aggString)"
val nonKeyedAggOpName = s"window: ($window), select: ($aggString)"
......@@ -123,21 +124,21 @@ class DataStreamAggregate(
inputDS.getType)
val needMerge = window match {
case ProcessingTimeSessionGroupWindow(_, _) => true
case EventTimeSessionGroupWindow(_, _, _) => true
case SessionGroupWindow(_, _, _) => true
case _ => false
}
val physicalGrouping = grouping.map(inputSchema.mapIndex)
// grouped / keyed aggregation
if (grouping.length > 0) {
if (physicalGrouping.length > 0) {
val windowFunction = AggregateUtil.createAggregationGroupWindowFunction(
window,
grouping.length,
namedAggregates.size,
rowRelDataType.getFieldCount,
physicalGrouping.length,
physicalNamedAggregates.size,
schema.physicalArity,
namedProperties)
val keyedStream = inputDS.keyBy(grouping: _*)
val keyedStream = inputDS.keyBy(physicalGrouping: _*)
val windowedStream =
createKeyedWindowedStream(window, keyedStream)
.asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]]
......@@ -145,20 +146,26 @@ class DataStreamAggregate(
val (aggFunction, accumulatorRowType, aggResultRowType) =
AggregateUtil.createDataStreamAggregateFunction(
generator,
namedAggregates,
inputType,
rowRelDataType,
physicalNamedAggregates,
inputSchema.physicalType,
inputSchema.physicalFieldTypeInfo,
schema.physicalType,
needMerge)
windowedStream
.aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo)
.aggregate(
aggFunction,
windowFunction,
accumulatorRowType,
aggResultRowType,
schema.physicalTypeInfo)
.name(keyedAggOpName)
}
// global / non-keyed aggregation
else {
val windowFunction = AggregateUtil.createAggregationAllWindowFunction(
window,
rowRelDataType.getFieldCount,
schema.physicalArity,
namedProperties)
val windowedStream =
......@@ -168,13 +175,19 @@ class DataStreamAggregate(
val (aggFunction, accumulatorRowType, aggResultRowType) =
AggregateUtil.createDataStreamAggregateFunction(
generator,
namedAggregates,
inputType,
rowRelDataType,
physicalNamedAggregates,
inputSchema.physicalType,
inputSchema.physicalFieldTypeInfo,
schema.physicalType,
needMerge)
windowedStream
.aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo)
.aggregate(
aggFunction,
windowFunction,
accumulatorRowType,
aggResultRowType,
schema.physicalTypeInfo)
.name(nonKeyedAggOpName)
}
}
......@@ -186,95 +199,102 @@ object DataStreamAggregate {
private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple])
: WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match {
case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) =>
stream.window(TumblingProcessingTimeWindows.of(asTime(size)))
case TumblingGroupWindow(_, timeField, size)
if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size) =>
stream.window(TumblingProcessingTimeWindows.of(toTime(size)))
case ProcessingTimeTumblingGroupWindow(_, size) =>
stream.countWindow(asCount(size))
case TumblingGroupWindow(_, timeField, size)
if isProctimeAttribute(timeField) && isRowCountLiteral(size) =>
stream.countWindow(toLong(size))
case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) =>
stream.window(TumblingEventTimeWindows.of(asTime(size)))
case TumblingGroupWindow(_, timeField, size)
if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size) =>
stream.window(TumblingEventTimeWindows.of(toTime(size)))
case EventTimeTumblingGroupWindow(_, _, size) =>
case TumblingGroupWindow(_, _, size) =>
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
throw new UnsupportedOperationException(
"Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) =>
stream.window(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide)))
case SlidingGroupWindow(_, timeField, size, slide)
if isProctimeAttribute(timeField) && isTimeIntervalLiteral(slide) =>
stream.window(SlidingProcessingTimeWindows.of(toTime(size), toTime(slide)))
case ProcessingTimeSlidingGroupWindow(_, size, slide) =>
stream.countWindow(asCount(size), asCount(slide))
case SlidingGroupWindow(_, timeField, size, slide)
if isProctimeAttribute(timeField) && isRowCountLiteral(size) =>
stream.countWindow(toLong(size), toLong(slide))
case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) =>
stream.window(SlidingEventTimeWindows.of(asTime(size), asTime(slide)))
case SlidingGroupWindow(_, timeField, size, slide)
if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size) =>
stream.window(SlidingEventTimeWindows.of(toTime(size), toTime(slide)))
case EventTimeSlidingGroupWindow(_, _, size, slide) =>
case SlidingGroupWindow(_, _, size, slide) =>
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
throw new UnsupportedOperationException(
"Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSessionGroupWindow(_, gap: Expression) =>
stream.window(ProcessingTimeSessionWindows.withGap(asTime(gap)))
case SessionGroupWindow(_, timeField, gap)
if isProctimeAttribute(timeField) =>
stream.window(ProcessingTimeSessionWindows.withGap(toTime(gap)))
case EventTimeSessionGroupWindow(_, _, gap) =>
stream.window(EventTimeSessionWindows.withGap(asTime(gap)))
case SessionGroupWindow(_, timeField, gap)
if isRowtimeAttribute(timeField) =>
stream.window(EventTimeSessionWindows.withGap(toTime(gap)))
}
private def createNonKeyedWindowedStream(groupWindow: LogicalWindow, stream: DataStream[Row])
: AllWindowedStream[Row, _ <: DataStreamWindow] = groupWindow match {
case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) =>
stream.windowAll(TumblingProcessingTimeWindows.of(asTime(size)))
case TumblingGroupWindow(_, timeField, size)
if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size) =>
stream.windowAll(TumblingProcessingTimeWindows.of(toTime(size)))
case ProcessingTimeTumblingGroupWindow(_, size) =>
stream.countWindowAll(asCount(size))
case TumblingGroupWindow(_, timeField, size)
if isProctimeAttribute(timeField) && isRowCountLiteral(size) =>
stream.countWindowAll(toLong(size))
case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) =>
stream.windowAll(TumblingEventTimeWindows.of(asTime(size)))
case TumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) =>
stream.windowAll(TumblingEventTimeWindows.of(toTime(size)))
case EventTimeTumblingGroupWindow(_, _, size) =>
case TumblingGroupWindow(_, _, size) =>
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
throw new UnsupportedOperationException(
"Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) =>
stream.windowAll(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide)))
case SlidingGroupWindow(_, timeField, size, slide)
if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size) =>
stream.windowAll(SlidingProcessingTimeWindows.of(toTime(size), toTime(slide)))
case ProcessingTimeSlidingGroupWindow(_, size, slide) =>
stream.countWindowAll(asCount(size), asCount(slide))
case SlidingGroupWindow(_, timeField, size, slide)
if isProctimeAttribute(timeField) && isRowCountLiteral(size) =>
stream.countWindowAll(toLong(size), toLong(slide))
case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) =>
stream.windowAll(SlidingEventTimeWindows.of(asTime(size), asTime(slide)))
case SlidingGroupWindow(_, timeField, size, slide)
if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size) =>
stream.windowAll(SlidingEventTimeWindows.of(toTime(size), toTime(slide)))
case EventTimeSlidingGroupWindow(_, _, size, slide) =>
case SlidingGroupWindow(_, _, size, slide) =>
// TODO: EventTimeTumblingGroupWindow should sort the stream on event time
// before applying the windowing logic. Otherwise, this would be the same as a
// ProcessingTimeTumblingGroupWindow
throw new UnsupportedOperationException(
"Event-time grouping windows on row intervals are currently not supported.")
case ProcessingTimeSessionGroupWindow(_, gap) =>
stream.windowAll(ProcessingTimeSessionWindows.withGap(asTime(gap)))
case SessionGroupWindow(_, timeField, gap)
if isProctimeAttribute(timeField) && isTimeIntervalLiteral(gap) =>
stream.windowAll(ProcessingTimeSessionWindows.withGap(toTime(gap)))
case EventTimeSessionGroupWindow(_, _, gap) =>
stream.windowAll(EventTimeSessionWindows.withGap(asTime(gap)))
case SessionGroupWindow(_, timeField, gap)
if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(gap) =>
stream.windowAll(EventTimeSessionWindows.withGap(toTime(gap)))
}
def asTime(expr: Expression): Time = expr match {
case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) => Time.milliseconds(value)
case _ => throw new IllegalArgumentException()
}
def asCount(expr: Expression): Long = expr match {
case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value
case _ => throw new IllegalArgumentException()
}
}
......@@ -27,9 +27,9 @@ import org.apache.calcite.rex.RexProgram
import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.api.StreamTableEnvironment
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.plan.nodes.CommonCalc
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.types.Row
/**
......@@ -40,17 +40,25 @@ class DataStreamCalc(
cluster: RelOptCluster,
traitSet: RelTraitSet,
input: RelNode,
rowRelDataType: RelDataType,
inputSchema: RowSchema,
schema: RowSchema,
calcProgram: RexProgram,
ruleDescription: String)
extends Calc(cluster, traitSet, input, calcProgram)
with CommonCalc
with DataStreamRel {
override def deriveRowType(): RelDataType = rowRelDataType
override def deriveRowType(): RelDataType = schema.logicalType
override def copy(traitSet: RelTraitSet, child: RelNode, program: RexProgram): Calc = {
new DataStreamCalc(cluster, traitSet, child, getRowType, program, ruleDescription)
new DataStreamCalc(
cluster,
traitSet,
child,
inputSchema,
schema,
program,
ruleDescription)
}
override def toString: String = calcToString(calcProgram, getExpressionString)
......@@ -85,8 +93,8 @@ class DataStreamCalc(
val body = functionBody(
generator,
inputDataStream.getType,
getRowType,
inputSchema,
schema,
calcProgram,
config)
......@@ -94,7 +102,7 @@ class DataStreamCalc(
ruleDescription,
classOf[FlatMapFunction[Row, Row]],
body,
FlinkTypeFactory.toInternalRowTypeInfo(getRowType))
schema.physicalTypeInfo)
val mapFunc = calcMapFunction(genFunction)
inputDataStream.flatMap(mapFunc).name(calcOpName(calcProgram, getExpressionString))
......
......@@ -18,7 +18,6 @@
package org.apache.flink.table.plan.nodes.datastream
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.rex.{RexCall, RexNode}
import org.apache.calcite.sql.SemiJoinType
......@@ -28,6 +27,7 @@ import org.apache.flink.table.api.StreamTableEnvironment
import org.apache.flink.table.functions.utils.TableSqlFunction
import org.apache.flink.table.plan.nodes.CommonCorrelate
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.types.Row
/**
......@@ -36,28 +36,30 @@ import org.apache.flink.types.Row
class DataStreamCorrelate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputSchema: RowSchema,
inputNode: RelNode,
scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode],
relRowType: RelDataType,
joinRowType: RelDataType,
schema: RowSchema,
joinSchema: RowSchema,
joinType: SemiJoinType,
ruleDescription: String)
extends SingleRel(cluster, traitSet, inputNode)
with CommonCorrelate
with DataStreamRel {
override def deriveRowType() = relRowType
override def deriveRowType() = schema.logicalType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamCorrelate(
cluster,
traitSet,
inputSchema,
inputs.get(0),
scan,
condition,
relRowType,
joinRowType,
schema,
joinSchema,
joinType,
ruleDescription)
}
......@@ -74,7 +76,7 @@ class DataStreamCorrelate(
super.explainTerms(pw)
.item("invocation", scan.getCall)
.item("function", sqlFunction.getTableFunction.getClass.getCanonicalName)
.item("rowType", relRowType)
.item("rowType", schema.logicalType)
.item("joinType", joinType)
.itemIf("condition", condition.orNull, condition.isDefined)
}
......@@ -94,16 +96,16 @@ class DataStreamCorrelate(
val mapFunc = correlateMapFunction(
config,
inputDS.getType,
inputSchema,
udtfTypeInfo,
getRowType,
schema,
joinType,
rexCall,
condition,
Some(pojoFieldMapping),
ruleDescription)
inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType))
inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, schema.logicalType))
}
}
......@@ -20,36 +20,34 @@ package org.apache.flink.table.plan.nodes.datastream
import java.util.{List => JList}
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{AggregateCall, Window}
import org.apache.calcite.rel.core.Window.Group
import org.apache.calcite.rel.core.{AggregateCall, Window}
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.api.java.functions.NullByteKeySelector
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.api.{StreamTableEnvironment, TableException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.plan.nodes.OverAggregate
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
import org.apache.flink.table.runtime.aggregate._
import org.apache.flink.types.Row
import org.apache.flink.api.java.functions.NullByteKeySelector
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.functions.{ProcTimeType, RowTimeType}
import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
class DataStreamOverAggregate(
logicWindow: Window,
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputNode: RelNode,
rowRelDataType: RelDataType,
inputType: RelDataType)
schema: RowSchema,
inputSchema: RowSchema)
extends SingleRel(cluster, traitSet, inputNode)
with OverAggregate
with DataStreamRel {
override def deriveRowType(): RelDataType = rowRelDataType
override def deriveRowType(): RelDataType = schema.logicalType
override def copy(traitSet: RelTraitSet, inputs: JList[RelNode]): RelNode = {
new DataStreamOverAggregate(
......@@ -57,8 +55,8 @@ class DataStreamOverAggregate(
cluster,
traitSet,
inputs.get(0),
getRowType,
inputType)
schema,
inputSchema)
}
override def toString: String = {
......@@ -72,14 +70,16 @@ class DataStreamOverAggregate(
val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
super.explainTerms(pw)
.itemIf("partitionBy", partitionToString(inputType, partitionKeys), partitionKeys.nonEmpty)
.item("orderBy",orderingToString(inputType, overWindow.orderKeys.getFieldCollations))
.itemIf("rows", windowRange(logicWindow, overWindow, getInput), overWindow.isRows)
.itemIf("range", windowRange(logicWindow, overWindow, getInput), !overWindow.isRows)
.itemIf("partitionBy",
partitionToString(schema.logicalType, partitionKeys), partitionKeys.nonEmpty)
.item("orderBy",
orderingToString(schema.logicalType, overWindow.orderKeys.getFieldCollations))
.itemIf("rows", windowRange(logicWindow, overWindow, inputNode), overWindow.isRows)
.itemIf("range", windowRange(logicWindow, overWindow, inputNode), !overWindow.isRows)
.item(
"select", aggregationToString(
inputType,
getRowType,
inputSchema.logicalType,
schema.logicalType,
namedAggregates))
}
......@@ -111,13 +111,13 @@ class DataStreamOverAggregate(
false,
inputDS.getType)
val timeType = inputType
val timeType = schema.logicalType
.getFieldList
.get(orderKey.getFieldIndex)
.getValue
.getType
timeType match {
case _: ProcTimeType =>
case _ if FlinkTypeFactory.isProctimeIndicatorType(timeType) =>
// proc-time OVER window
if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
// unbounded OVER window
......@@ -140,7 +140,8 @@ class DataStreamOverAggregate(
throw new TableException(
"OVER RANGE FOLLOWING windows are not supported yet.")
}
case _: RowTimeType =>
case _ if FlinkTypeFactory.isRowtimeIndicatorType(timeType) =>
// row-time OVER window
if (overWindow.lowerBound.isPreceding &&
overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
......@@ -158,17 +159,16 @@ class DataStreamOverAggregate(
inputDS,
isRowTimeType = true,
isRowsClause = overWindow.isRows
)
)
} else {
throw new TableException(
"OVER RANGE FOLLOWING windows are not supported yet.")
}
case _ =>
throw new TableException(
"Unsupported time type {$timeType}. " +
"OVER windows do only support RowTimeType and ProcTimeType.")
s"OVER windows can only be applied on time attributes.")
}
}
def createUnboundedAndCurrentRowOverWindow(
......@@ -178,16 +178,20 @@ class DataStreamOverAggregate(
isRowsClause: Boolean): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
// get the output types
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
val partitionKeys: Array[Int] = overWindow.keys.toArray.map(schema.mapIndex)
val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates.map {
namedAggregate =>
new CalcitePair[AggregateCall, String](
schema.mapAggregateCall(namedAggregate.left),
namedAggregate.right)
}
val processFunction = AggregateUtil.createUnboundedOverProcessFunction(
generator,
namedAggregates,
inputType,
inputSchema.physicalType,
inputSchema.physicalTypeInfo,
inputSchema.physicalFieldTypeInfo,
isRowTimeType,
partitionKeys.nonEmpty,
isRowsClause)
......@@ -198,7 +202,7 @@ class DataStreamOverAggregate(
inputDS
.keyBy(partitionKeys: _*)
.process(processFunction)
.returns(rowTypeInfo)
.returns(schema.physicalTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
......@@ -207,13 +211,13 @@ class DataStreamOverAggregate(
if (isRowTimeType) {
inputDS.keyBy(new NullByteKeySelector[Row])
.process(processFunction).setParallelism(1).setMaxParallelism(1)
.returns(rowTypeInfo)
.returns(schema.physicalTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
} else {
inputDS
.process(processFunction).setParallelism(1).setMaxParallelism(1)
.returns(rowTypeInfo)
.returns(schema.physicalTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
......@@ -228,19 +232,26 @@ class DataStreamOverAggregate(
isRowsClause: Boolean): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
val partitionKeys: Array[Int] = overWindow.keys.toArray.map(schema.mapIndex)
val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates.map {
namedAggregate =>
new CalcitePair[AggregateCall, String](
schema.mapAggregateCall(namedAggregate.left),
namedAggregate.right)
}
val precedingOffset =
getLowerBoundary(logicWindow, overWindow, getInput()) + (if (isRowsClause) 1 else 0)
// get the output types
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
getLowerBoundary(
logicWindow,
overWindow,
input) + (if (isRowsClause) 1 else 0)
val processFunction = AggregateUtil.createBoundedOverProcessFunction(
generator,
namedAggregates,
inputType,
inputSchema.physicalType,
inputSchema.physicalTypeInfo,
inputSchema.physicalFieldTypeInfo,
precedingOffset,
isRowsClause,
isRowTimeType
......@@ -251,7 +262,7 @@ class DataStreamOverAggregate(
inputDS
.keyBy(partitionKeys: _*)
.process(processFunction)
.returns(rowTypeInfo)
.returns(schema.physicalTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
......@@ -260,7 +271,7 @@ class DataStreamOverAggregate(
inputDS
.keyBy(new NullByteKeySelector[Row])
.process(processFunction).setParallelism(1).setMaxParallelism(1)
.returns(rowTypeInfo)
.returns(schema.physicalTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
}
......@@ -282,17 +293,18 @@ class DataStreamOverAggregate(
s"over: (${
if (!partitionKeys.isEmpty) {
s"PARTITION BY: ${partitionToString(inputType, partitionKeys)}, "
s"PARTITION BY: ${partitionToString(inputSchema.logicalType, partitionKeys)}, "
} else {
""
}
}ORDER BY: ${orderingToString(inputType, overWindow.orderKeys.getFieldCollations)}, " +
}ORDER BY: ${orderingToString(inputSchema.logicalType,
overWindow.orderKeys.getFieldCollations)}, " +
s"${if (overWindow.isRows) "ROWS" else "RANGE"}" +
s"${windowRange(logicWindow, overWindow, getInput)}, " +
s"${windowRange(logicWindow, overWindow, inputNode.asInstanceOf[DataStreamRel])}, " +
s"select: (${
aggregationToString(
inputType,
getRowType,
inputSchema.logicalType,
schema.logicalType,
namedAggregates)
}))"
}
......
......@@ -34,4 +34,3 @@ trait DataStreamRel extends FlinkRelNode {
def translateToPlan(tableEnv: StreamTableEnvironment) : DataStream[Row]
}
......@@ -24,7 +24,7 @@ import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.TableScan
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.api.StreamTableEnvironment
import org.apache.flink.table.plan.schema.DataStreamTable
import org.apache.flink.table.plan.schema.{DataStreamTable, RowSchema}
import org.apache.flink.types.Row
/**
......@@ -36,27 +36,27 @@ class DataStreamScan(
cluster: RelOptCluster,
traitSet: RelTraitSet,
table: RelOptTable,
rowRelDataType: RelDataType)
schema: RowSchema)
extends TableScan(cluster, traitSet, table)
with StreamScan {
val dataStreamTable: DataStreamTable[Any] = getTable.unwrap(classOf[DataStreamTable[Any]])
override def deriveRowType(): RelDataType = rowRelDataType
override def deriveRowType(): RelDataType = schema.logicalType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamScan(
cluster,
traitSet,
getTable,
getRowType
schema
)
}
override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = {
val config = tableEnv.getConfig
val inputDataStream: DataStream[Any] = dataStreamTable.dataStream
convertToInternalRow(inputDataStream, dataStreamTable, config)
convertToInternalRow(schema, inputDataStream, dataStreamTable, config)
}
}
......@@ -19,14 +19,12 @@
package org.apache.flink.table.plan.nodes.datastream
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.api.StreamTableEnvironment
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.types.Row
import scala.collection.JavaConverters._
/**
* Flink RelNode which matches along with Union.
*
......@@ -36,11 +34,11 @@ class DataStreamUnion(
traitSet: RelTraitSet,
leftNode: RelNode,
rightNode: RelNode,
rowRelDataType: RelDataType)
schema: RowSchema)
extends BiRel(cluster, traitSet, leftNode, rightNode)
with DataStreamRel {
override def deriveRowType() = rowRelDataType
override def deriveRowType() = schema.logicalType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamUnion(
......@@ -48,7 +46,7 @@ class DataStreamUnion(
traitSet,
inputs.get(0),
inputs.get(1),
getRowType
schema
)
}
......@@ -57,7 +55,7 @@ class DataStreamUnion(
}
override def toString = {
s"Union All(union: (${getRowType.getFieldNames.asScala.toList.mkString(", ")}))"
s"Union All(union: (${schema.logicalFieldNames.mkString(", ")}))"
}
override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = {
......@@ -68,6 +66,6 @@ class DataStreamUnion(
}
private def unionSelectionToString: String = {
getRowType.getFieldNames.asScala.toList.mkString(", ")
schema.logicalFieldNames.mkString(", ")
}
}
......@@ -21,13 +21,12 @@ package org.apache.flink.table.plan.nodes.datastream
import com.google.common.collect.ImmutableList
import org.apache.calcite.plan._
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.Values
import org.apache.calcite.rex.RexLiteral
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.api.StreamTableEnvironment
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.io.ValuesInputFormat
import org.apache.flink.types.Row
......@@ -39,19 +38,19 @@ import scala.collection.JavaConverters._
class DataStreamValues(
cluster: RelOptCluster,
traitSet: RelTraitSet,
rowRelDataType: RelDataType,
schema: RowSchema,
tuples: ImmutableList[ImmutableList[RexLiteral]],
ruleDescription: String)
extends Values(cluster, rowRelDataType, tuples, traitSet)
extends Values(cluster, schema.logicalType, tuples, traitSet)
with DataStreamRel {
override def deriveRowType() = rowRelDataType
override def deriveRowType() = schema.logicalType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamValues(
cluster,
traitSet,
getRowType,
schema,
getTuples,
ruleDescription
)
......@@ -61,15 +60,13 @@ class DataStreamValues(
val config = tableEnv.getConfig
val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
val generator = new CodeGenerator(config)
// generate code for every record
val generatedRecords = getTuples.asScala.map { r =>
generator.generateResultExpression(
returnType,
getRowType.getFieldNames.asScala,
schema.physicalTypeInfo,
schema.physicalFieldNames,
r.asScala)
}
......@@ -77,14 +74,14 @@ class DataStreamValues(
val generatedFunction = generator.generateValuesInputFormat(
ruleDescription,
generatedRecords.map(_.code),
returnType)
schema.physicalTypeInfo)
val inputFormat = new ValuesInputFormat[Row](
generatedFunction.name,
generatedFunction.code,
generatedFunction.returnType)
tableEnv.execEnv.createInput(inputFormat, returnType)
tableEnv.execEnv.createInput(inputFormat, schema.physicalTypeInfo)
}
}
......@@ -18,42 +18,46 @@
package org.apache.flink.table.plan.nodes.datastream
import org.apache.flink.api.common.functions.MapFunction
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.nodes.CommonScan
import org.apache.flink.table.plan.schema.FlinkTable
import org.apache.flink.table.plan.schema.{FlinkTable, RowSchema}
import org.apache.flink.table.runtime.MapRunner
import org.apache.flink.types.Row
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
trait StreamScan extends CommonScan with DataStreamRel {
protected def convertToInternalRow(
schema: RowSchema,
input: DataStream[Any],
flinkTable: FlinkTable[_],
config: TableConfig)
: DataStream[Row] = {
val inputType = input.getType
val internalType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
// conversion
if (needsConversion(inputType, internalType)) {
if (needsConversion(input.getType, schema.physicalTypeInfo)) {
val mapFunc = getConversionMapper(
val function = generatedConversionFunction(
config,
inputType,
internalType,
classOf[MapFunction[Any, Row]],
input.getType,
schema.physicalTypeInfo,
"DataStreamSourceConversion",
getRowType.getFieldNames,
schema.physicalFieldNames,
Some(flinkTable.fieldIndexes))
val opName = s"from: (${getRowType.getFieldNames.asScala.toList.mkString(", ")})"
val runner = new MapRunner[Any, Row](
function.name,
function.code,
function.returnType)
val opName = s"from: (${schema.logicalFieldNames.mkString(", ")})"
input.map(mapFunc).name(opName)
// TODO we need a ProcessFunction here
input.map(runner).name(opName)
}
// no conversion necessary, forward
else {
......
......@@ -22,10 +22,11 @@ import org.apache.calcite.plan._
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.api.StreamTableEnvironment
import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.nodes.PhysicalTableSourceScan
import org.apache.flink.table.plan.schema.TableSourceTable
import org.apache.flink.table.sources.{StreamTableSource, TableSource}
import org.apache.flink.table.plan.schema.{RowSchema, TableSourceTable}
import org.apache.flink.table.sources.{DefinedTimeAttributes, StreamTableSource, TableSource}
import org.apache.flink.types.Row
/** Flink RelNode to read data from an external source defined by a [[StreamTableSource]]. */
......@@ -37,7 +38,50 @@ class StreamTableSourceScan(
extends PhysicalTableSourceScan(cluster, traitSet, table, tableSource)
with StreamScan {
override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
override def deriveRowType() = {
val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
def removeIndex[T](idx: Int, l: List[T]): List[T] = {
if (l.size < idx) {
l
} else {
l.take(idx) ++ l.drop(idx + 1)
}
}
var fieldNames = TableEnvironment.getFieldNames(tableSource).toList
var fieldTypes = TableEnvironment.getFieldTypes(tableSource.getReturnType).toList
val rowtime = tableSource match {
case timeSource: DefinedTimeAttributes if timeSource.getRowtimeAttribute != null =>
val rowtimeAttribute = timeSource.getRowtimeAttribute
// remove physical field if it is overwritten by time attribute
fieldNames = removeIndex(rowtimeAttribute.f0, fieldNames)
fieldTypes = removeIndex(rowtimeAttribute.f0, fieldTypes)
Some((rowtimeAttribute.f0, rowtimeAttribute.f1))
case _ =>
None
}
val proctime = tableSource match {
case timeSource: DefinedTimeAttributes if timeSource.getProctimeAttribute != null =>
val proctimeAttribute = timeSource.getProctimeAttribute
// remove physical field if it is overwritten by time attribute
fieldNames = removeIndex(proctimeAttribute.f0, fieldNames)
fieldTypes = removeIndex(proctimeAttribute.f0, fieldTypes)
Some((proctimeAttribute.f0, proctimeAttribute.f1))
case _ =>
None
}
flinkTypeFactory.buildLogicalRowType(
fieldNames,
fieldTypes,
rowtime,
proctime)
}
override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val rowCnt = metadata.getRowCount(this)
planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * estimateRowSize(getRowType))
}
......@@ -67,6 +111,10 @@ class StreamTableSourceScan(
override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = {
val config = tableEnv.getConfig
val inputDataStream = tableSource.getDataStream(tableEnv.execEnv).asInstanceOf[DataStream[Any]]
convertToInternalRow(inputDataStream, new TableSourceTable(tableSource), config)
convertToInternalRow(
new RowSchema(getRowType),
inputDataStream,
new TableSourceTable(tableSource),
config)
}
}
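A minimal standalone sketch (not part of the patch) of the field-removal behavior used by the new deriveRowType of StreamTableSourceScan above: a time attribute whose position falls inside the physical row replaces that field, while a position at or beyond the physical arity leaves the fields untouched so the attribute is appended logically. The field names below are made up for illustration.
object RemoveIndexSketch {
  // mirrors the removeIndex helper in StreamTableSourceScan.deriveRowType
  def removeIndex[T](idx: Int, l: List[T]): List[T] =
    if (l.size < idx) l else l.take(idx) ++ l.drop(idx + 1)
  def main(args: Array[String]): Unit = {
    val physicalFields = List("id", "ts", "amount")
    // a rowtime attribute at position 1 overwrites the physical field "ts"
    println(removeIndex(1, physicalFields)) // List(id, amount)
    // a rowtime attribute at position 3 (beyond the arity) is appended logically
    println(removeIndex(3, physicalFields)) // List(id, ts, amount)
  }
}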
......@@ -45,7 +45,7 @@ class FlinkLogicalOverWindow(
traitSet,
inputs.get(0),
windowConstants,
rowType,
getRowType,
windowGroups)
}
}
......
......@@ -47,9 +47,11 @@ class FlinkLogicalTableSourceScan(
override def deriveRowType(): RelDataType = {
val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
flinkTypeFactory.buildRowDataType(
flinkTypeFactory.buildLogicalRowType(
TableEnvironment.getFieldNames(tableSource),
TableEnvironment.getFieldTypes(tableSource.getReturnType))
TableEnvironment.getFieldTypes(tableSource.getReturnType),
None,
None)
}
override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
......
......@@ -53,8 +53,8 @@ class WindowStartEndPropertiesRule
transformed.push(LogicalWindowAggregate.create(
agg.getWindow,
Seq(
NamedWindowProperty("w$start", WindowStart(agg.getWindow.alias)),
NamedWindowProperty("w$end", WindowEnd(agg.getWindow.alias))
NamedWindowProperty("w$start", WindowStart(agg.getWindow.aliasAttribute)),
NamedWindowProperty("w$end", WindowEnd(agg.getWindow.aliasAttribute))
), agg)
)
......
......@@ -25,6 +25,7 @@ import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.datastream.DataStreamAggregate
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalWindowAggregate
import org.apache.flink.table.plan.schema.RowSchema
import scala.collection.JavaConversions._
......@@ -65,8 +66,8 @@ class DataStreamAggregateRule
traitSet,
convInput,
agg.getNamedAggCalls,
rel.getRowType,
agg.getInput.getRowType,
new RowSchema(rel.getRowType),
new RowSchema(agg.getInput.getRowType),
agg.getGroupSet.toArray)
}
}
......
......@@ -24,6 +24,7 @@ import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.datastream.DataStreamCalc
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCalc
import org.apache.flink.table.plan.schema.RowSchema
class DataStreamCalcRule
extends ConverterRule(
......@@ -42,7 +43,8 @@ class DataStreamCalcRule
rel.getCluster,
traitSet,
convInput,
rel.getRowType,
new RowSchema(convInput.getRowType),
new RowSchema(rel.getRowType),
calc.getProgram,
description)
}
......
......@@ -25,6 +25,7 @@ import org.apache.calcite.rex.RexNode
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.datastream.DataStreamCorrelate
import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalTableFunctionScan}
import org.apache.flink.table.plan.schema.RowSchema
class DataStreamCorrelateRule
extends ConverterRule(
......@@ -68,11 +69,12 @@ class DataStreamCorrelateRule
new DataStreamCorrelate(
rel.getCluster,
traitSet,
new RowSchema(convInput.getRowType),
convInput,
scan,
condition,
rel.getRowType,
join.getRowType,
new RowSchema(rel.getRowType),
new RowSchema(join.getRowType),
join.getJoinType,
description)
}
......
......@@ -18,15 +18,15 @@
package org.apache.flink.table.plan.rules.datastream
import java.math.BigDecimal
import java.math.{BigDecimal => JBigDecimal}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex.{RexBuilder, RexCall, RexLiteral, RexNode}
import org.apache.calcite.rex._
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.flink.table.api.{TableException, Window}
import org.apache.flink.table.api.scala.{Session, Slide, Tumble}
import org.apache.flink.table.expressions.Literal
import org.apache.flink.table.functions.TimeModeTypes
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.expressions.{Literal, UnresolvedFieldReference}
import org.apache.flink.table.plan.rules.common.LogicalWindowAggregateRule
import org.apache.flink.table.typeutils.TimeIntervalTypeInfo
......@@ -49,16 +49,12 @@ class DataStreamLogicalWindowAggregateRule
val timeType = windowExpression.operands.get(0).getType
timeType match {
case TimeModeTypes.ROWTIME =>
rexBuilder.makeAbstractCast(
TimeModeTypes.ROWTIME,
rexBuilder.makeLiteral(0L, TimeModeTypes.ROWTIME, true))
case TimeModeTypes.PROCTIME =>
rexBuilder.makeAbstractCast(
TimeModeTypes.PROCTIME,
rexBuilder.makeLiteral(0L, TimeModeTypes.PROCTIME, true))
case _ if FlinkTypeFactory.isTimeIndicatorType(timeType) =>
rexBuilder.makeLiteral(0L, timeType, true)
case _ =>
throw TableException(s"""Unexpected time type $timeType encountered""")
throw TableException(s"""Time attribute expected but $timeType encountered.""")
}
}
......@@ -68,41 +64,41 @@ class DataStreamLogicalWindowAggregateRule
def getOperandAsLong(call: RexCall, idx: Int): Long =
call.getOperands.get(idx) match {
case v : RexLiteral => v.getValue.asInstanceOf[BigDecimal].longValue()
case _ => throw new TableException("Only constant window descriptors are supported")
case v: RexLiteral => v.getValue.asInstanceOf[JBigDecimal].longValue()
case _ => throw new TableException("Only constant window descriptors are supported.")
}
def getOperandAsTimeIndicator(call: RexCall, idx: Int): String =
call.getOperands.get(idx) match {
case v: RexInputRef if FlinkTypeFactory.isTimeIndicatorType(v.getType) =>
rowType.getFieldList.get(v.getIndex).getName
case _ =>
throw new TableException("Window can only be defined over a time attribute column.")
}
windowExpr.getOperator match {
case SqlStdOperatorTable.TUMBLE =>
val time = getOperandAsTimeIndicator(windowExpr, 0)
val interval = getOperandAsLong(windowExpr, 1)
val w = Tumble.over(Literal(interval, TimeIntervalTypeInfo.INTERVAL_MILLIS))
val window = windowExpr.getType match {
case TimeModeTypes.PROCTIME => w
case TimeModeTypes.ROWTIME => w.on("rowtime")
}
window.as("w$")
w.on(UnresolvedFieldReference(time)).as("w$")
case SqlStdOperatorTable.HOP =>
val time = getOperandAsTimeIndicator(windowExpr, 0)
val (slide, size) = (getOperandAsLong(windowExpr, 1), getOperandAsLong(windowExpr, 2))
val w = Slide
.over(Literal(size, TimeIntervalTypeInfo.INTERVAL_MILLIS))
.every(Literal(slide, TimeIntervalTypeInfo.INTERVAL_MILLIS))
val window = windowExpr.getType match {
case TimeModeTypes.PROCTIME => w
case TimeModeTypes.ROWTIME => w.on("rowtime")
}
window.as("w$")
w.on(UnresolvedFieldReference(time)).as("w$")
case SqlStdOperatorTable.SESSION =>
val time = getOperandAsTimeIndicator(windowExpr, 0)
val gap = getOperandAsLong(windowExpr, 1)
val w = Session.withGap(Literal(gap, TimeIntervalTypeInfo.INTERVAL_MILLIS))
val window = windowExpr.getType match {
case TimeModeTypes.PROCTIME => w
case TimeModeTypes.ROWTIME => w.on("rowtime")
}
window.as("w$")
w.on(UnresolvedFieldReference(time)).as("w$")
}
}
}
......
......@@ -25,6 +25,7 @@ import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.datastream.DataStreamOverAggregate
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalOverWindow
import org.apache.flink.table.plan.schema.RowSchema
class DataStreamOverAggregateRule
extends ConverterRule(
......@@ -46,8 +47,8 @@ class DataStreamOverAggregateRule
rel.getCluster,
traitSet,
convertInput,
rel.getRowType,
inputRowType)
new RowSchema(rel.getRowType),
new RowSchema(inputRowType))
}
}
......
......@@ -23,7 +23,7 @@ import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.datastream.DataStreamScan
import org.apache.flink.table.plan.schema.DataStreamTable
import org.apache.flink.table.plan.schema.{DataStreamTable, RowSchema}
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalNativeTableScan
class DataStreamScanRule
......@@ -53,7 +53,7 @@ class DataStreamScanRule
rel.getCluster,
traitSet,
scan.getTable,
rel.getRowType
new RowSchema(rel.getRowType)
)
}
}
......
......@@ -24,6 +24,7 @@ import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.datastream.DataStreamUnion
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalUnion
import org.apache.flink.table.plan.schema.RowSchema
class DataStreamUnionRule
extends ConverterRule(
......@@ -44,7 +45,7 @@ class DataStreamUnionRule
traitSet,
convLeft,
convRight,
rel.getRowType)
new RowSchema(rel.getRowType))
}
}
......
......@@ -24,6 +24,7 @@ import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.datastream.DataStreamValues
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalValues
import org.apache.flink.table.plan.schema.RowSchema
class DataStreamValuesRule
extends ConverterRule(
......@@ -40,7 +41,7 @@ class DataStreamValuesRule
new DataStreamValues(
rel.getCluster,
traitSet,
rel.getRowType,
new RowSchema(rel.getRowType),
values.getTuples,
description)
}
......
......@@ -18,13 +18,27 @@
package org.apache.flink.table.plan.schema
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory}
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.stats.FlinkStatistic
class DataStreamTable[T](
val dataStream: DataStream[T],
override val fieldIndexes: Array[Int],
override val fieldNames: Array[String],
val rowtime: Option[(Int, String)],
val proctime: Option[(Int, String)],
override val statistic: FlinkStatistic = FlinkStatistic.UNKNOWN)
extends FlinkTable[T](dataStream.getType, fieldIndexes, fieldNames, statistic) {
override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = {
val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory]
flinkTypeFactory.buildLogicalRowType(
fieldNames,
fieldTypes,
rowtime,
proctime)
}
}
......@@ -48,10 +48,11 @@ abstract class FlinkTable[T](
val fieldTypes: Array[TypeInformation[_]] =
typeInfo match {
case cType: CompositeType[_] =>
if (fieldNames.length != cType.getArity) {
// it is ok to leave out fields
if (fieldNames.length > cType.getArity) {
throw new TableException(
s"Arity of type (" + cType.getFieldNames.deep + ") " +
"not equal to number of field names " + fieldNames.deep + ".")
"must not be greater than number of field names " + fieldNames.deep + ".")
}
fieldIndexes.map(cType.getTypeAt(_).asInstanceOf[TypeInformation[_]])
case aType: AtomicType[_] =>
......@@ -64,7 +65,7 @@ abstract class FlinkTable[T](
override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = {
val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory]
flinkTypeFactory.buildRowDataType(fieldNames, fieldTypes)
flinkTypeFactory.buildLogicalRowType(fieldNames, fieldTypes, None, None)
}
/**
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.schema
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField, RelRecordType}
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rex.{RexInputRef, RexNode, RexShuttle}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.TableException
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.types.Row
import scala.collection.JavaConversions._
/**
* Schema that describes both a logical and physical row.
*/
class RowSchema(private val logicalRowType: RelDataType) {
private lazy val physicalRowFields: Seq[RelDataTypeField] = logicalRowType.getFieldList filter {
field => !FlinkTypeFactory.isTimeIndicatorType(field.getType)
}
private lazy val physicalRowType: RelDataType = new RelRecordType(physicalRowFields)
private lazy val physicalRowFieldTypes: Seq[TypeInformation[_]] = physicalRowFields map { f =>
FlinkTypeFactory.toTypeInfo(f.getType)
}
private lazy val physicalRowFieldNames: Seq[String] = physicalRowFields.map(_.getName)
private lazy val physicalRowTypeInfo: TypeInformation[Row] = new RowTypeInfo(
physicalRowFieldTypes.toArray, physicalRowFieldNames.toArray)
private lazy val indexMapping: Array[Int] = generateIndexMapping
private lazy val inputRefUpdater = new RexInputRefUpdater()
private def generateIndexMapping: Array[Int] = {
val mapping = new Array[Int](logicalRowType.getFieldCount)
var countTimeIndicators = 0
var i = 0
while (i < logicalRowType.getFieldCount) {
val t = logicalRowType.getFieldList.get(i).getType
if (FlinkTypeFactory.isTimeIndicatorType(t)) {
countTimeIndicators += 1
// no mapping
mapping(i) = -1
} else {
mapping(i) = i - countTimeIndicators
}
i += 1
}
mapping
}
private class RexInputRefUpdater extends RexShuttle {
override def visitInputRef(inputRef: RexInputRef): RexNode = {
new RexInputRef(mapIndex(inputRef.getIndex), inputRef.getType)
}
}
/**
* Returns the arity of the logical record.
*/
def logicalArity: Int = logicalRowType.getFieldCount
/**
* Returns the arity of the physical record.
*/
def physicalArity: Int = physicalTypeInfo.getArity
/**
* Returns a logical [[RelDataType]] including logical fields (i.e. time indicators).
*/
def logicalType: RelDataType = logicalRowType
/**
* Returns a physical [[RelDataType]] with no logical fields (i.e. time indicators).
*/
def physicalType: RelDataType = physicalRowType
/**
* Returns a physical [[TypeInformation]] of row with no logical fields (i.e. time indicators).
*/
def physicalTypeInfo: TypeInformation[Row] = physicalRowTypeInfo
/**
* Returns [[TypeInformation]] of the row's fields with no logical fields (i.e. time indicators).
*/
def physicalFieldTypeInfo: Seq[TypeInformation[_]] = physicalRowFieldTypes
/**
* Returns the logical field names, including logical fields (i.e. time indicators).
*/
def logicalFieldNames: Seq[String] = logicalRowType.getFieldNames
/**
* Returns the physical field names, without logical fields (i.e. time indicators).
*/
def physicalFieldNames: Seq[String] = physicalRowFieldNames
/**
* Converts logical indices to physical indices based on this schema.
*/
def mapIndex(logicalIndex: Int): Int = {
val mappedIndex = indexMapping(logicalIndex)
if (mappedIndex < 0) {
throw new TableException("Invalid access to a logical field.")
} else {
mappedIndex
}
}
/**
* Converts logical indices of an aggregate call to physical ones.
*/
def mapAggregateCall(logicalAggCall: AggregateCall): AggregateCall = {
logicalAggCall.copy(
logicalAggCall.getArgList.map(mapIndex(_).asInstanceOf[Integer]),
if (logicalAggCall.filterArg < 0) {
logicalAggCall.filterArg
} else {
mapIndex(logicalAggCall.filterArg)
}
)
}
/**
* Converts logical field references of a [[RexNode]] to physical ones.
*/
def mapRexNode(logicalRexNode: RexNode): RexNode = logicalRexNode.accept(inputRefUpdater)
}
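A minimal standalone sketch (not part of the patch) of the logical-to-physical index mapping that generateIndexMapping builds: time-indicator fields receive no physical position (-1), and every following field is shifted left by the number of indicators skipped so far. The boolean mask below is an assumption standing in for the FlinkTypeFactory.isTimeIndicatorType checks.
object IndexMappingSketch {
  def buildMapping(isTimeIndicator: Seq[Boolean]): Array[Int] = {
    var skipped = 0
    isTimeIndicator.zipWithIndex.map { case (indicator, i) =>
      if (indicator) { skipped += 1; -1 } else i - skipped
    }.toArray
  }
  def main(args: Array[String]): Unit = {
    // logical row (a, rowtime, b, proctime) maps to physical row (a, b)
    val mapping = buildMapping(Seq(false, true, false, true))
    println(mapping.mkString(", ")) // 0, -1, 1, -1
  }
}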
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.schema
import org.apache.calcite.rel.`type`.RelDataTypeSystem
import org.apache.calcite.sql.`type`.BasicSqlType
/**
* Creates a time indicator type for event-time or processing-time, but with properties
* similar to a basic SQL type.
*/
class TimeIndicatorRelDataType(
typeSystem: RelDataTypeSystem,
originalType: BasicSqlType,
val isEventTime: Boolean)
extends BasicSqlType(
typeSystem,
originalType.getSqlTypeName,
originalType.getPrecision) {
override def equals(other: Any): Boolean = other match {
case that: TimeIndicatorRelDataType =>
super.equals(that) &&
isEventTime == that.isEventTime
case that: BasicSqlType =>
super.equals(that)
case _ => false
}
override def hashCode(): Int = {
super.hashCode() + 42 // we change the hash code to differentiate from regular timestamps
}
}
......@@ -35,7 +35,7 @@ class MapRunner[IN, OUT](
val LOG = LoggerFactory.getLogger(this.getClass)
private var function: MapFunction[IN, OUT] = null
private var function: MapFunction[IN, OUT] = _
override def open(parameters: Configuration): Unit = {
LOG.debug(s"Compiling MapFunction: $name \n\n Code:\n$code")
......
......@@ -21,7 +21,7 @@ import java.util.{List => JList, ArrayList => JArrayList}
import org.apache.flink.api.common.state._
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
import org.apache.flink.api.java.typeutils.ListTypeInfo
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
......@@ -39,8 +39,8 @@ import org.slf4j.LoggerFactory
*/
class RowTimeBoundedRangeOver(
genAggregations: GeneratedAggregationsFunction,
aggregationStateType: RowTypeInfo,
inputRowType: RowTypeInfo,
aggregationStateType: TypeInformation[Row],
inputRowType: TypeInformation[Row],
precedingOffset: Long)
extends ProcessFunction[Row, Row]
with Compiler[GeneratedAggregations] {
......
......@@ -41,7 +41,7 @@ import org.slf4j.LoggerFactory
class RowTimeBoundedRowsOver(
genAggregations: GeneratedAggregationsFunction,
aggregationStateType: RowTypeInfo,
inputRowType: RowTypeInfo,
inputRowType: TypeInformation[Row],
precedingOffset: Long)
extends ProcessFunction[Row, Row]
with Compiler[GeneratedAggregations] {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.sources
import org.apache.flink.api.java.tuple.Tuple2
/**
* Defines logical time attributes for a [[TableSource]]. Time attributes can be used to
* indicate, access, and work with Flink's event-time or processing-time. A [[TableSource]]
* that implements this interface can define the names and positions of the rowtime and
* proctime attributes in the rows it produces.
*/
trait DefinedTimeAttributes {
/**
* Defines the name and position (starting at 0) of a rowtime attribute that represents
* Flink's event-time. Returns null if no rowtime attribute should be available. If the
* position is within the arity of the result row, the logical attribute overwrites the
* physical field at that position. If the position is equal to or greater than the arity
* of the result row, the time attribute is appended logically.
*/
def getRowtimeAttribute: Tuple2[Int, String]
/**
* Defines the name and position (starting at 0) of a proctime attribute that represents
* Flink's processing-time. Returns null if no proctime attribute should be available. If
* the position is within the arity of the result row, the logical attribute overwrites the
* physical field at that position. If the position is equal to or greater than the arity
* of the result row, the time attribute is appended logically.
*/
def getProctimeAttribute: Tuple2[Int, String]
}
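A minimal sketch (not part of the patch) of how a source could implement this trait; the class name, field positions, and attribute name are assumptions for illustration. In a real setup this trait would be mixed into a StreamTableSource so that the planner can pick up the time attributes when deriving the row type.
import org.apache.flink.api.java.tuple.Tuple2
import org.apache.flink.table.sources.DefinedTimeAttributes
class OrdersTimeAttributes extends DefinedTimeAttributes {
  // append an event-time attribute named "rowtime" after two physical fields (positions 0 and 1)
  override def getRowtimeAttribute: Tuple2[Int, String] = new Tuple2(2, "rowtime")
  // no processing-time attribute for this source
  override def getProctimeAttribute: Tuple2[Int, String] = null
}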
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.typeutils
import java.sql.Timestamp
import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo
import org.apache.flink.api.common.typeutils.TypeComparator
import org.apache.flink.api.common.typeutils.base.{SqlTimestampComparator, SqlTimestampSerializer}
/**
* Type information for indicating event-time or processing-time. At runtime, however, it
* behaves like a regular SQL timestamp.
*/
class TimeIndicatorTypeInfo(val isEventTime: Boolean)
extends SqlTimeTypeInfo[Timestamp](
classOf[Timestamp],
SqlTimestampSerializer.INSTANCE,
classOf[SqlTimestampComparator].asInstanceOf[Class[TypeComparator[Timestamp]]]) {
override def toString: String = s"TimeIndicatorTypeInfo"
}
object TimeIndicatorTypeInfo {
val ROWTIME_INDICATOR = new TimeIndicatorTypeInfo(true)
val PROCTIME_INDICATOR = new TimeIndicatorTypeInfo(false)
}
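A small sketch (not part of the patch) showing that the indicator singletons behave like plain SQL timestamps at runtime and differ only in the event-time flag:
import java.sql.Timestamp
import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo
object TimeIndicatorSketch {
  def main(args: Array[String]): Unit = {
    val rowtime = TimeIndicatorTypeInfo.ROWTIME_INDICATOR
    val proctime = TimeIndicatorTypeInfo.PROCTIME_INDICATOR
    // both indicators are backed by java.sql.Timestamp, like a regular SQL timestamp type
    println(rowtime.getTypeClass == classOf[Timestamp]) // true
    // only the flag distinguishes event time from processing time
    println(rowtime.isEventTime)  // true
    println(proctime.isEventTime) // false
  }
}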
......@@ -17,7 +17,7 @@
*/
package org.apache.flink.table.typeutils
import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO, INT_TYPE_INFO, STRING_TYPE_INFO}
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo._
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo
import org.apache.flink.table.validate._
......@@ -29,6 +29,7 @@ object TypeCheckUtils {
* SQL type but NOT vice versa.
*/
def isAdvanced(dataType: TypeInformation[_]): Boolean = dataType match {
case _: TimeIndicatorTypeInfo => false
case _: BasicTypeInfo[_] => false
case _: SqlTimeTypeInfo[_] => false
case _: TimeIntervalTypeInfo[_] => false
......@@ -64,6 +65,8 @@ object TypeCheckUtils {
def isInteger(dataType: TypeInformation[_]): Boolean = dataType == INT_TYPE_INFO
def isLong(dataType: TypeInformation[_]): Boolean = dataType == LONG_TYPE_INFO
def isArray(dataType: TypeInformation[_]): Boolean = dataType match {
case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] => true
case _ => false
......
......@@ -24,7 +24,7 @@ import org.apache.calcite.sql.{SqlFunction, SqlOperator, SqlOperatorTable}
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.utils.{AggSqlFunction, ScalarSqlFunction, TableSqlFunction}
import org.apache.flink.table.functions.{AggregateFunction, EventTimeExtractor, RowTime, ScalarFunction, TableFunction, _}
import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction}
import scala.collection.JavaConversions._
import scala.collection.mutable
......@@ -242,15 +242,11 @@ object FunctionCatalog {
// array
"cardinality" -> classOf[ArrayCardinality],
"at" -> classOf[ArrayElementAt],
"element" -> classOf[ArrayElement],
"element" -> classOf[ArrayElement]
// TODO implement function overloading here
// "floor" -> classOf[TemporalFloor]
// "ceil" -> classOf[TemporalCeil]
// extensions to support streaming query
"rowtime" -> classOf[RowTime],
"proctime" -> classOf[ProcTime]
)
/**
......@@ -392,8 +388,6 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
SqlStdOperatorTable.ROUND,
SqlStdOperatorTable.PI,
// EXTENSIONS
EventTimeExtractor,
ProcTimeExtractor,
SqlStdOperatorTable.TUMBLE,
SqlStdOperatorTable.TUMBLE_START,
SqlStdOperatorTable.TUMBLE_END,
......
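With the rowtime()/proctime() marker functions and the EventTimeExtractor/ProcTimeExtractor operators removed, a time attribute is declared once in the table schema and then referenced like any other column, for example inside the group windows (TUMBLE etc.) that remain registered above. A sketch under that assumption; the stream, table name, and field names are made up for illustration.

import org.apache.flink.streaming.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._

object ProctimeWindowSql {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // declare the processing-time attribute as part of the schema at registration time
    val stream = env.fromElements((1L, "a"), (2L, "b"))
    tEnv.registerDataStream("Events", stream, 'id, 'name, 'proctime.proctime)

    // the attribute is referenced like a column; no proctime() function call is needed
    val result = tEnv.sql(
      "SELECT name, COUNT(*) FROM Events " +
        "GROUP BY name, TUMBLE(proctime, INTERVAL '10' SECOND)")
  }
}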
......@@ -405,15 +405,6 @@ public class TableEnvironmentITCase extends TableProgramsCollectionTestBase {
tableEnv.fromDataSet(dataSet, "nullField");
}
@Test(expected = TableException.class)
public void testAsWithToFewFields() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
// Must fail. Not enough field names specified.
tableEnv.fromDataSet(CollectionDataSets.get3TupleDataSet(env), "a, b");
}
@Test(expected = TableException.class)
public void testAsWithToManyFields() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
......
......@@ -93,7 +93,8 @@ class TableEnvironmentTest extends TableTestBase {
UnresolvedFieldReference("name1"),
UnresolvedFieldReference("name2"),
UnresolvedFieldReference("name3")
))
),
ignoreTimeAttributes = true)
fieldInfo._1.zip(Array("name1", "name2", "name3")).foreach(x => assertEquals(x._2, x._1))
fieldInfo._2.zip(Array(0, 1, 2)).foreach(x => assertEquals(x._2, x._1))
......@@ -107,7 +108,8 @@ class TableEnvironmentTest extends TableTestBase {
UnresolvedFieldReference("name1"),
UnresolvedFieldReference("name2"),
UnresolvedFieldReference("name3")
))
),
ignoreTimeAttributes = true)
fieldInfo._1.zip(Array("name1", "name2", "name3")).foreach(x => assertEquals(x._2, x._1))
fieldInfo._2.zip(Array(0, 1, 2)).foreach(x => assertEquals(x._2, x._1))
......@@ -121,7 +123,8 @@ class TableEnvironmentTest extends TableTestBase {
UnresolvedFieldReference("name1"),
UnresolvedFieldReference("name2"),
UnresolvedFieldReference("name3")
))
),
ignoreTimeAttributes = true)
}
@Test
......@@ -132,7 +135,8 @@ class TableEnvironmentTest extends TableTestBase {
UnresolvedFieldReference("pf3"),
UnresolvedFieldReference("pf1"),
UnresolvedFieldReference("pf2")
))
),
ignoreTimeAttributes = true)
fieldInfo._1.zip(Array("pf3", "pf1", "pf2")).foreach(x => assertEquals(x._2, x._1))
fieldInfo._2.zip(Array(2, 0, 1)).foreach(x => assertEquals(x._2, x._1))
......@@ -142,7 +146,8 @@ class TableEnvironmentTest extends TableTestBase {
def testGetFieldInfoAtomicName1(): Unit = {
val fieldInfo = tEnv.getFieldInfo(
atomicType,
Array(UnresolvedFieldReference("name"))
Array(UnresolvedFieldReference("name")),
ignoreTimeAttributes = true
)
fieldInfo._1.zip(Array("name")).foreach(x => assertEquals(x._2, x._1))
......@@ -156,7 +161,8 @@ class TableEnvironmentTest extends TableTestBase {
Array(
UnresolvedFieldReference("name1"),
UnresolvedFieldReference("name2")
))
),
ignoreTimeAttributes = true)
}
@Test
......@@ -167,7 +173,8 @@ class TableEnvironmentTest extends TableTestBase {
Alias(UnresolvedFieldReference("f0"), "name1"),
Alias(UnresolvedFieldReference("f1"), "name2"),
Alias(UnresolvedFieldReference("f2"), "name3")
))
),
ignoreTimeAttributes = true)
fieldInfo._1.zip(Array("name1", "name2", "name3")).foreach(x => assertEquals(x._2, x._1))
fieldInfo._2.zip(Array(0, 1, 2)).foreach(x => assertEquals(x._2, x._1))
......@@ -181,7 +188,8 @@ class TableEnvironmentTest extends TableTestBase {
Alias(UnresolvedFieldReference("f2"), "name1"),
Alias(UnresolvedFieldReference("f0"), "name2"),
Alias(UnresolvedFieldReference("f1"), "name3")
))
),
ignoreTimeAttributes = true)
fieldInfo._1.zip(Array("name1", "name2", "name3")).foreach(x => assertEquals(x._2, x._1))
fieldInfo._2.zip(Array(2, 0, 1)).foreach(x => assertEquals(x._2, x._1))
......@@ -195,7 +203,8 @@ class TableEnvironmentTest extends TableTestBase {
Alias(UnresolvedFieldReference("xxx"), "name1"),
Alias(UnresolvedFieldReference("yyy"), "name2"),
Alias(UnresolvedFieldReference("zzz"), "name3")
))
),
ignoreTimeAttributes = true)
}
@Test
......@@ -206,7 +215,8 @@ class TableEnvironmentTest extends TableTestBase {
Alias(UnresolvedFieldReference("cf1"), "name1"),
Alias(UnresolvedFieldReference("cf2"), "name2"),
Alias(UnresolvedFieldReference("cf3"), "name3")
))
),
ignoreTimeAttributes = true)
fieldInfo._1.zip(Array("name1", "name2", "name3")).foreach(x => assertEquals(x._2, x._1))
fieldInfo._2.zip(Array(0, 1, 2)).foreach(x => assertEquals(x._2, x._1))
......@@ -220,7 +230,8 @@ class TableEnvironmentTest extends TableTestBase {
Alias(UnresolvedFieldReference("cf3"), "name1"),
Alias(UnresolvedFieldReference("cf1"), "name2"),
Alias(UnresolvedFieldReference("cf2"), "name3")
))
),
ignoreTimeAttributes = true)
fieldInfo._1.zip(Array("name1", "name2", "name3")).foreach(x => assertEquals(x._2, x._1))
fieldInfo._2.zip(Array(2, 0, 1)).foreach(x => assertEquals(x._2, x._1))
......@@ -234,7 +245,8 @@ class TableEnvironmentTest extends TableTestBase {
Alias(UnresolvedFieldReference("xxx"), "name1"),
Alias(UnresolvedFieldReference("yyy"), "name2"),
Alias(UnresolvedFieldReference("zzz"), "name3")
))
),
ignoreTimeAttributes = true)
}
@Test
......@@ -245,7 +257,8 @@ class TableEnvironmentTest extends TableTestBase {
Alias(UnresolvedFieldReference("pf1"), "name1"),
Alias(UnresolvedFieldReference("pf2"), "name2"),
Alias(UnresolvedFieldReference("pf3"), "name3")
))
),
ignoreTimeAttributes = true)
fieldInfo._1.zip(Array("name1", "name2", "name3")).foreach(x => assertEquals(x._2, x._1))
fieldInfo._2.zip(Array(0, 1, 2)).foreach(x => assertEquals(x._2, x._1))
......@@ -259,7 +272,8 @@ class TableEnvironmentTest extends TableTestBase {
Alias(UnresolvedFieldReference("pf3"), "name1"),
Alias(UnresolvedFieldReference("pf1"), "name2"),
Alias(UnresolvedFieldReference("pf2"), "name3")
))
),
ignoreTimeAttributes = true)
fieldInfo._1.zip(Array("name1", "name2", "name3")).foreach(x => assertEquals(x._2, x._1))
fieldInfo._2.zip(Array(2, 0, 1)).foreach(x => assertEquals(x._2, x._1))
......@@ -272,8 +286,9 @@ class TableEnvironmentTest extends TableTestBase {
Array(
Alias(UnresolvedFieldReference("xxx"), "name1"),
Alias(UnresolvedFieldReference("yyy"), "name2"),
Alias( UnresolvedFieldReference("zzz"), "name3")
))
Alias(UnresolvedFieldReference("zzz"), "name3")
),
ignoreTimeAttributes = true)
}
@Test(expected = classOf[TableException])
......@@ -282,12 +297,16 @@ class TableEnvironmentTest extends TableTestBase {
atomicType,
Array(
Alias(UnresolvedFieldReference("name1"), "name2")
))
),
ignoreTimeAttributes = true)
}
@Test(expected = classOf[TableException])
def testGetFieldInfoGenericRowAlias(): Unit = {
tEnv.getFieldInfo(genericRowType, Array(UnresolvedFieldReference("first")))
tEnv.getFieldInfo(
genericRowType,
Array(UnresolvedFieldReference("first")),
ignoreTimeAttributes = true)
}
@Test
......
......@@ -207,16 +207,6 @@ class TableEnvironmentITCase(
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test(expected = classOf[TableException])
def testToTableWithToFewFields(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)
CollectionDataSets.get3TupleDataSet(env)
// Must fail. Number of fields does not match.
.toTable(tEnv, 'a, 'b)
}
@Test(expected = classOf[TableException])
def testToTableWithToManyFields(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
......
......@@ -110,22 +110,6 @@ class CalcITCase extends StreamingMultipleProgramsTestBase {
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
@Test(expected = classOf[TableException])
def testAsWithToFewFields(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
StreamITCase.testResults = mutable.MutableList()
val ds = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b)
val results = ds.toDataStream[Row]
results.addSink(new StreamITCase.StringSink)
env.execute()
val expected = mutable.MutableList("no")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
@Test(expected = classOf[TableException])
def testAsWithToManyFields(): Unit = {
......