提交 cbc5f830 编写于 作者: D dawidwys 提交者: Fabian Hueske

[Flink-2971] [tableAPI] Add outer joins to the Table API and SQL.

This closes #1981
上级 573a92fc
...@@ -423,11 +423,47 @@ Table result = in.groupBy("a").select("a, b.sum as d"); ...@@ -423,11 +423,47 @@ Table result = in.groupBy("a").select("a, b.sum as d");
<tr> <tr>
<td><strong>Join</strong></td> <td><strong>Join</strong></td>
<td> <td>
<p>Similar to a SQL JOIN clause. Joins two tables. Both tables must have distinct field names and an equality join predicate must be defined using a where or filter operator.</p> <p>Similar to a SQL JOIN clause. Joins two tables. Both tables must have distinct field names and at least one equality join predicate must be defined through join operator or using a where or filter operator.</p>
{% highlight java %} {% highlight java %}
Table left = tableEnv.fromDataSet(ds1, "a, b, c"); Table left = tableEnv.fromDataSet(ds1, "a, b, c");
Table right = tableEnv.fromDataSet(ds2, "d, e, f"); Table right = tableEnv.fromDataSet(ds2, "d, e, f");
Table result = left.join(right).where("a = d").select("a, b, e"); Table result = left.join(right).where("a = d").select("a, b, e");
{% endhighlight %}
</td>
</tr>
<tr>
<td><strong>LeftOuterJoin</strong></td>
<td>
<p>Similar to a SQL LEFT OUTER JOIN clause. Joins two tables. Both tables must have distinct field names and at least one equality join predicate must be defined.</p>
{% highlight java %}
Table left = tableEnv.fromDataSet(ds1, "a, b, c");
Table right = tableEnv.fromDataSet(ds2, "d, e, f");
Table result = left.leftOuterJoin(right, "a = d").select("a, b, e");
{% endhighlight %}
</td>
</tr>
<tr>
<td><strong>RightOuterJoin</strong></td>
<td>
<p>Similar to a SQL RIGHT OUTER JOIN clause. Joins two tables. Both tables must have distinct field names and at least one equality join predicate must be defined.</p>
{% highlight java %}
Table left = tableEnv.fromDataSet(ds1, "a, b, c");
Table right = tableEnv.fromDataSet(ds2, "d, e, f");
Table result = left.rightOuterJoin(right, "a = d").select("a, b, e");
{% endhighlight %}
</td>
</tr>
<tr>
<td><strong>FullOuterJoin</strong></td>
<td>
<p>Similar to a SQL FULL OUTER JOIN clause. Joins two tables. Both tables must have distinct field names and at least one equality join predicate must be defined.</p>
{% highlight java %}
Table left = tableEnv.fromDataSet(ds1, "a, b, c");
Table right = tableEnv.fromDataSet(ds2, "d, e, f");
Table result = left.fullOuterJoin(right, "a = d").select("a, b, e");
{% endhighlight %} {% endhighlight %}
</td> </td>
</tr> </tr>
...@@ -551,6 +587,42 @@ val result = in.groupBy('a).select('a, 'b.sum as 'd); ...@@ -551,6 +587,42 @@ val result = in.groupBy('a).select('a, 'b.sum as 'd);
val left = ds1.toTable(tableEnv, 'a, 'b, 'c); val left = ds1.toTable(tableEnv, 'a, 'b, 'c);
val right = ds2.toTable(tableEnv, 'd, 'e, 'f); val right = ds2.toTable(tableEnv, 'd, 'e, 'f);
val result = left.join(right).where('a === 'd).select('a, 'b, 'e); val result = left.join(right).where('a === 'd).select('a, 'b, 'e);
{% endhighlight %}
</td>
</tr>
<tr>
<td><strong>LeftOuterJoin</strong></td>
<td>
<p>Similar to a SQL LEFT OUTER JOIN clause. Joins two tables. Both tables must have distinct field names and at least one equality join predicate must be defined.</p>
{% highlight scala %}
val left = tableEnv.fromDataSet(ds1, 'a, 'b, 'c)
val right = tableEnv.fromDataSet(ds2, 'd, 'e, 'f)
val result = left.leftOuterJoin(right, 'a === 'd).select('a, 'b, 'e)
{% endhighlight %}
</td>
</tr>
<tr>
<td><strong>RightOuterJoin</strong></td>
<td>
<p>Similar to a SQL RIGHT OUTER JOIN clause. Joins two tables. Both tables must have distinct field names and at least one equality join predicate must be defined.</p>
{% highlight scala %}
val left = tableEnv.fromDataSet(ds1, 'a, 'b, 'c)
val right = tableEnv.fromDataSet(ds2, 'd, 'e, 'f)
val result = left.rightOuterJoin(right, 'a === 'd).select('a, 'b, 'e)
{% endhighlight %}
</td>
</tr>
<tr>
<td><strong>FullOuterJoin</strong></td>
<td>
<p>Similar to a SQL FULL OUTER JOIN clause. Joins two tables. Both tables must have distinct field names and at least one equality join predicate must be defined.</p>
{% highlight scala %}
val left = tableEnv.fromDataSet(ds1, 'a, 'b, 'c)
val right = tableEnv.fromDataSet(ds2, 'd, 'e, 'f)
val result = left.fullOuterJoin(right, 'a === 'd).select('a, 'b, 'e)
{% endhighlight %} {% endhighlight %}
</td> </td>
</tr> </tr>
...@@ -711,13 +783,12 @@ Among others, the following SQL features are not supported, yet: ...@@ -711,13 +783,12 @@ Among others, the following SQL features are not supported, yet:
- Time data types (`DATE`, `TIME`, `TIMESTAMP`, `INTERVAL`) and `DECIMAL` types - Time data types (`DATE`, `TIME`, `TIMESTAMP`, `INTERVAL`) and `DECIMAL` types
- Distinct aggregates (e.g., `COUNT(DISTINCT name)`) - Distinct aggregates (e.g., `COUNT(DISTINCT name)`)
- Outer joins
- Non-equi joins and Cartesian products - Non-equi joins and Cartesian products
- Result selection by order position (`ORDER BY OFFSET FETCH`) - Result selection by order position (`ORDER BY OFFSET FETCH`)
- Grouping sets - Grouping sets
- Set operations except `UNION ALL` (`INTERSECT`, `UNION`, `EXCEPT`) - `INTERSECT` and `EXCEPT` set operations
*Note: Tables are joined in the order in which they are specified in the `FROM` clause. In some cases the table order must be manually tweaked to resolve Cartesian products. Certain rewrites during optimization (e.g., subquery decorrelation) can result in unsupported operations such as outer joins.* *Note: Tables are joined in the order in which they are specified in the `FROM` clause. In some cases the table order must be manually tweaked to resolve Cartesian products.*
### SQL on Streaming Tables ### SQL on Streaming Tables
......
...@@ -19,14 +19,14 @@ package org.apache.flink.api.table.plan.logical ...@@ -19,14 +19,14 @@ package org.apache.flink.api.table.plan.logical
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import scala.collection.mutable import scala.collection.mutable
import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinRelType import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rel.logical.LogicalProject import org.apache.calcite.rel.logical.LogicalProject
import org.apache.calcite.rex.{RexInputRef, RexNode}
import org.apache.calcite.tools.RelBuilder import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.api.table._ import org.apache.flink.api.table._
import org.apache.flink.api.table.expressions._ import org.apache.flink.api.table.expressions._
...@@ -269,22 +269,62 @@ case class Join( ...@@ -269,22 +269,62 @@ case class Join(
condition: Option[Expression]) extends BinaryNode { condition: Option[Expression]) extends BinaryNode {
override def output: Seq[Attribute] = { override def output: Seq[Attribute] = {
joinType match { left.output ++ right.output
case JoinType.INNER => left.output ++ right.output }
case j => throw new ValidationException(s"Unsupported JoinType: $j")
private case class JoinFieldReference(
name: String,
resultType: TypeInformation[_],
left: LogicalNode,
right: LogicalNode) extends Attribute {
val isFromLeftInput = left.output.map(_.name).contains(name)
val (indexInInput, indexInJoin) = if (isFromLeftInput) {
val indexInLeft = left.output.map(_.name).indexOf(name)
(indexInLeft, indexInLeft)
} else {
val indexInRight = right.output.map(_.name).indexOf(name)
(indexInRight, indexInRight + left.output.length)
}
override def toString = s"'$name"
override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
// look up type of field
val fieldType = relBuilder.field(2, if (isFromLeftInput) 0 else 1, name).getType
// create a new RexInputRef with index offset
new RexInputRef(indexInJoin, fieldType)
}
override def withName(newName: String): Attribute = {
if (newName == name) {
this
} else {
JoinFieldReference(newName, resultType, left, right)
}
} }
} }
override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { override def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
joinType match { val node = super.resolveExpressions(tableEnv).asInstanceOf[Join]
case JoinType.INNER => val partialFunction: PartialFunction[Expression, Expression] = {
left.construct(relBuilder) case field: ResolvedFieldReference => JoinFieldReference(
right.construct(relBuilder) field.name,
relBuilder.join(JoinRelType.INNER, field.resultType,
condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true))) left,
case _ => right)
throw new ValidationException(s"Unsupported JoinType: $joinType")
} }
val resolvedCondition = node.condition.map(_.postOrderTransform(partialFunction))
new Join(node.left, node.right, node.joinType, resolvedCondition)
}
override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
left.construct(relBuilder)
right.construct(relBuilder)
relBuilder.join(
TypeConverter.flinkJoinTypeToRelType(joinType),
condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)))
} }
private def ambiguousName: Set[String] = private def ambiguousName: Set[String] =
...@@ -298,11 +338,42 @@ case class Join( ...@@ -298,11 +338,42 @@ case class Join(
val resolvedJoin = super.validate(tableEnv).asInstanceOf[Join] val resolvedJoin = super.validate(tableEnv).asInstanceOf[Join]
if (!resolvedJoin.condition.forall(_.resultType == BOOLEAN_TYPE_INFO)) { if (!resolvedJoin.condition.forall(_.resultType == BOOLEAN_TYPE_INFO)) {
failValidation(s"filter expression ${resolvedJoin.condition} is not a boolean") failValidation(s"filter expression ${resolvedJoin.condition} is not a boolean")
} else if (!ambiguousName.isEmpty) { } else if (ambiguousName.nonEmpty) {
failValidation(s"join relations with ambiguous names: ${ambiguousName.mkString(", ")}") failValidation(s"join relations with ambiguous names: ${ambiguousName.mkString(", ")}")
} }
resolvedJoin.condition.foreach(testJoinCondition(_))
resolvedJoin resolvedJoin
} }
private def testJoinCondition(expression: Expression): Unit = {
def checkIfJoinCondition(exp : BinaryComparison) = exp.children match {
case (x : JoinFieldReference) :: (y : JoinFieldReference) :: Nil
if x.isFromLeftInput != y.isFromLeftInput => Unit
case x => failValidation(
s"Invalid non-join predicate $exp. For non-join predicates use Table#where.")
}
var equiJoinFound = false
def validateConditions(exp: Expression, isAndBranch: Boolean): Unit = exp match {
case x: And => x.children.foreach(validateConditions(_, isAndBranch))
case x: Or => x.children.foreach(validateConditions(_, isAndBranch = false))
case x: EqualTo =>
if (isAndBranch) {
equiJoinFound = true
}
checkIfJoinCondition(x)
case x: BinaryComparison => checkIfJoinCondition(x)
case x => failValidation(
s"Unsupported condition type: ${x.getClass.getSimpleName}. Condition: $x")
}
validateConditions(expression, isAndBranch = true)
if (!equiJoinFound) {
failValidation(s"Invalid join condition: $expression. At least one equi-join required.")
}
}
} }
case class CatalogNode( case class CatalogNode(
......
...@@ -20,24 +20,23 @@ package org.apache.flink.api.table.plan.nodes.dataset ...@@ -20,24 +20,23 @@ package org.apache.flink.api.table.plan.nodes.dataset
import org.apache.calcite.plan._ import org.apache.calcite.plan._
import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinInfo import org.apache.calcite.rel.core.{JoinInfo, JoinRelType}
import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelWriter, BiRel, RelNode} import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
import org.apache.calcite.util.mapping.IntPair import org.apache.calcite.util.mapping.IntPair
import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint
import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.api.table.codegen.CodeGenerator import org.apache.flink.api.table.codegen.CodeGenerator
import org.apache.flink.api.table.runtime.FlatJoinRunner import org.apache.flink.api.table.runtime.FlatJoinRunner
import org.apache.flink.api.table.typeutils.TypeConverter.determineReturnType import org.apache.flink.api.table.typeutils.TypeConverter.determineReturnType
import org.apache.flink.api.table.{BatchTableEnvironment, TableException} import org.apache.flink.api.table.{BatchTableEnvironment, TableException}
import org.apache.flink.api.common.functions.FlatJoinFunction import org.apache.flink.api.common.functions.FlatJoinFunction
import scala.collection.mutable.ArrayBuffer
import org.apache.calcite.rex.RexNode import org.apache.calcite.rex.RexNode
import scala.collection.JavaConverters._
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
/** /**
* Flink RelNode which matches along with JoinOperator and its related operations. * Flink RelNode which matches along with JoinOperator and its related operations.
...@@ -52,7 +51,7 @@ class DataSetJoin( ...@@ -52,7 +51,7 @@ class DataSetJoin(
joinRowType: RelDataType, joinRowType: RelDataType,
joinInfo: JoinInfo, joinInfo: JoinInfo,
keyPairs: List[IntPair], keyPairs: List[IntPair],
joinType: JoinType, joinType: JoinRelType,
joinHint: JoinHint, joinHint: JoinHint,
ruleDescription: String) ruleDescription: String)
extends BiRel(cluster, traitSet, left, right) extends BiRel(cluster, traitSet, left, right)
...@@ -77,13 +76,14 @@ class DataSetJoin( ...@@ -77,13 +76,14 @@ class DataSetJoin(
} }
override def toString: String = { override def toString: String = {
s"Join(where: ($joinConditionToString), join: ($joinSelectionToString))" s"$joinTypeToString(where: ($joinConditionToString), join: ($joinSelectionToString))"
} }
override def explainTerms(pw: RelWriter): RelWriter = { override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw) super.explainTerms(pw)
.item("where", joinConditionToString) .item("where", joinConditionToString)
.item("join", joinSelectionToString) .item("join", joinSelectionToString)
.item("joinType", joinTypeToString)
} }
override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
...@@ -148,9 +148,20 @@ class DataSetJoin( ...@@ -148,9 +148,20 @@ class DataSetJoin(
val leftDataSet = left.asInstanceOf[DataSetRel].translateToPlan(tableEnv) val leftDataSet = left.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
val rightDataSet = right.asInstanceOf[DataSetRel].translateToPlan(tableEnv) val rightDataSet = right.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
val (joinOperator, nullCheck) = joinType match {
case JoinRelType.INNER => (leftDataSet.join(rightDataSet), false)
case JoinRelType.LEFT => (leftDataSet.leftOuterJoin(rightDataSet), true)
case JoinRelType.RIGHT => (leftDataSet.rightOuterJoin(rightDataSet), true)
case JoinRelType.FULL => (leftDataSet.fullOuterJoin(rightDataSet), true)
}
if (nullCheck && !config.getNullCheck) {
throw new TableException("Null check in TableConfig must be enabled for outer joins.")
}
val generator = new CodeGenerator( val generator = new CodeGenerator(
config, config,
false, nullCheck,
leftDataSet.getType, leftDataSet.getType,
Some(rightDataSet.getType)) Some(rightDataSet.getType))
val conversion = generator.generateConverterResultExpression( val conversion = generator.generateConverterResultExpression(
...@@ -189,7 +200,7 @@ class DataSetJoin( ...@@ -189,7 +200,7 @@ class DataSetJoin(
val joinOpName = s"where: ($joinConditionToString), join: ($joinSelectionToString)" val joinOpName = s"where: ($joinConditionToString), join: ($joinSelectionToString)"
leftDataSet.join(rightDataSet).where(leftKeys.toArray: _*).equalTo(rightKeys.toArray: _*) joinOperator.where(leftKeys.toArray: _*).equalTo(rightKeys.toArray: _*)
.`with`(joinFun).name(joinOpName).asInstanceOf[DataSet[Any]] .`with`(joinFun).name(joinOpName).asInstanceOf[DataSet[Any]]
} }
...@@ -203,4 +214,11 @@ class DataSetJoin( ...@@ -203,4 +214,11 @@ class DataSetJoin(
getExpressionString(joinCondition, inFields, None) getExpressionString(joinCondition, inFields, None)
} }
private def joinTypeToString = joinType match {
case JoinRelType.INNER => "Join"
case JoinRelType.LEFT=> "LeftOuterJoin"
case JoinRelType.RIGHT => "RightOuterJoin"
case JoinRelType.FULL => "FullOuterJoin"
}
} }
...@@ -18,13 +18,11 @@ ...@@ -18,13 +18,11 @@
package org.apache.flink.api.table.plan.rules.dataSet package org.apache.flink.api.table.plan.rules.dataSet
import org.apache.calcite.plan.{RelOptRuleCall, Convention, RelOptRule, RelTraitSet} import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rel.logical.LogicalJoin import org.apache.calcite.rel.logical.LogicalJoin
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.api.table.TableException
import org.apache.flink.api.table.plan.nodes.dataset.{DataSetJoin, DataSetConvention} import org.apache.flink.api.table.plan.nodes.dataset.{DataSetJoin, DataSetConvention}
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
...@@ -42,15 +40,7 @@ class DataSetJoinRule ...@@ -42,15 +40,7 @@ class DataSetJoinRule
val joinInfo = join.analyzeCondition val joinInfo = join.analyzeCondition
// joins require an equi-condition or a conjunctive predicate with at least one equi-condition // joins require an equi-condition or a conjunctive predicate with at least one equi-condition
val hasValidCondition = !joinInfo.pairs().isEmpty !joinInfo.pairs().isEmpty
// only inner joins are supported at the moment
val isInnerJoin = join.getJoinType.equals(JoinRelType.INNER)
if (!isInnerJoin) {
throw new TableException("OUTER JOIN is currently not supported.")
}
// check that condition is valid and inner join
hasValidCondition && isInnerJoin
} }
override def convert(rel: RelNode): RelNode = { override def convert(rel: RelNode): RelNode = {
...@@ -71,10 +61,11 @@ class DataSetJoinRule ...@@ -71,10 +61,11 @@ class DataSetJoinRule
join.getRowType, join.getRowType,
joinInfo, joinInfo,
joinInfo.pairs.toList, joinInfo.pairs.toList,
JoinType.INNER, join.getJoinType,
null, null,
description) description)
} }
} }
object DataSetJoinRule { object DataSetJoinRule {
......
...@@ -18,12 +18,10 @@ ...@@ -18,12 +18,10 @@
package org.apache.flink.api.table package org.apache.flink.api.table
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.RelNode
import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.api.table.plan.RexNodeTranslator.extractAggregations import org.apache.flink.api.table.plan.RexNodeTranslator.extractAggregations
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.api.table.expressions._ import org.apache.flink.api.table.expressions._
import org.apache.flink.api.table.plan.logical._ import org.apache.flink.api.table.plan.logical._
import org.apache.flink.api.table.sinks.TableSink import org.apache.flink.api.table.sinks.TableSink
...@@ -251,12 +249,157 @@ class Table( ...@@ -251,12 +249,157 @@ class Table(
* }}} * }}}
*/ */
def join(right: Table): Table = { def join(right: Table): Table = {
join(right, None, JoinType.INNER)
}
/**
* Joins two [[Table]]s. Similar to an SQL join. The fields of the two joined
* operations must not overlap, use [[as]] to rename fields if necessary.
*
* Note: Both tables must be bound to the same [[TableEnvironment]].
*
* Example:
*
* {{{
* left.join(right, "a = b")
* }}}
*/
def join(right: Table, joinPredicate: String): Table = {
join(right, joinPredicate, JoinType.INNER)
}
/**
* Joins two [[Table]]s. Similar to an SQL join. The fields of the two joined
* operations must not overlap, use [[as]] to rename fields if necessary.
*
* Note: Both tables must be bound to the same [[TableEnvironment]].
*
* Example:
*
* {{{
* left.join(right, 'a === 'b).select('a, 'b, 'd)
* }}}
*/
def join(right: Table, joinPredicate: Expression): Table = {
join(right, Some(joinPredicate), JoinType.INNER)
}
/**
* Joins two [[Table]]s. Similar to an SQL left outer join. The fields of the two joined
* operations must not overlap, use [[as]] to rename fields if necessary.
*
* Note: Both tables must be bound to the same [[TableEnvironment]] and its [[TableConfig]] must
* have nullCheck enabled.
*
* Example:
*
* {{{
* left.leftOuterJoin(right, "a = b").select('a, 'b, 'd)
* }}}
*/
def leftOuterJoin(right: Table, joinPredicate: String): Table = {
join(right, joinPredicate, JoinType.LEFT_OUTER)
}
/**
* Joins two [[Table]]s. Similar to an SQL left outer join. The fields of the two joined
* operations must not overlap, use [[as]] to rename fields if necessary.
*
* Note: Both tables must be bound to the same [[TableEnvironment]] and its [[TableConfig]] must
* have nullCheck enabled.
*
* Example:
*
* {{{
* left.leftOuterJoin(right, 'a === 'b).select('a, 'b, 'd)
* }}}
*/
def leftOuterJoin(right: Table, joinPredicate: Expression): Table = {
join(right, Some(joinPredicate), JoinType.LEFT_OUTER)
}
/**
* Joins two [[Table]]s. Similar to an SQL right outer join. The fields of the two joined
* operations must not overlap, use [[as]] to rename fields if necessary.
*
* Note: Both tables must be bound to the same [[TableEnvironment]] and its [[TableConfig]] must
* have nullCheck enabled.
*
* Example:
*
* {{{
* left.rightOuterJoin(right, "a = b").select('a, 'b, 'd)
* }}}
*/
def rightOuterJoin(right: Table, joinPredicate: String): Table = {
join(right, joinPredicate, JoinType.RIGHT_OUTER)
}
/**
* Joins two [[Table]]s. Similar to an SQL right outer join. The fields of the two joined
* operations must not overlap, use [[as]] to rename fields if necessary.
*
* Note: Both tables must be bound to the same [[TableEnvironment]] and its [[TableConfig]] must
* have nullCheck enabled.
*
* Example:
*
* {{{
* left.rightOuterJoin(right, 'a === 'b).select('a, 'b, 'd)
* }}}
*/
def rightOuterJoin(right: Table, joinPredicate: Expression): Table = {
join(right, Some(joinPredicate), JoinType.RIGHT_OUTER)
}
/**
* Joins two [[Table]]s. Similar to an SQL full outer join. The fields of the two joined
* operations must not overlap, use [[as]] to rename fields if necessary.
*
* Note: Both tables must be bound to the same [[TableEnvironment]] and its [[TableConfig]] must
* have nullCheck enabled.
*
* Example:
*
* {{{
* left.fullOuterJoin(right, "a = b").select('a, 'b, 'd)
* }}}
*/
def fullOuterJoin(right: Table, joinPredicate: String): Table = {
join(right, joinPredicate, JoinType.FULL_OUTER)
}
/**
* Joins two [[Table]]s. Similar to an SQL full outer join. The fields of the two joined
* operations must not overlap, use [[as]] to rename fields if necessary.
*
* Note: Both tables must be bound to the same [[TableEnvironment]] and its [[TableConfig]] must
* have nullCheck enabled.
*
* Example:
*
* {{{
* left.fullOuterJoin(right, 'a === 'b).select('a, 'b, 'd)
* }}}
*/
def fullOuterJoin(right: Table, joinPredicate: Expression): Table = {
join(right, Some(joinPredicate), JoinType.FULL_OUTER)
}
private def join(right: Table, joinPredicate: String, joinType: JoinType): Table = {
val joinPredicateExpr = ExpressionParser.parseExpression(joinPredicate)
join(right, Some(joinPredicateExpr), joinType)
}
private def join(right: Table, joinPredicate: Option[Expression], joinType: JoinType): Table = {
// check that right table belongs to the same TableEnvironment // check that right table belongs to the same TableEnvironment
if (right.tableEnv != this.tableEnv) { if (right.tableEnv != this.tableEnv) {
throw new ValidationException("Only tables from the same TableEnvironment can be joined.") throw new ValidationException("Only tables from the same TableEnvironment can be joined.")
} }
new Table(tableEnv, new Table(
Join(this.logicalPlan, right.logicalPlan, JoinType.INNER, None).validate(tableEnv)) tableEnv,
Join(this.logicalPlan, right.logicalPlan, joinType, joinPredicate).validate(tableEnv))
} }
/** /**
......
...@@ -197,4 +197,11 @@ object TypeConverter { ...@@ -197,4 +197,11 @@ object TypeConverter {
case RIGHT => JoinType.RIGHT_OUTER case RIGHT => JoinType.RIGHT_OUTER
case FULL => JoinType.FULL_OUTER case FULL => JoinType.FULL_OUTER
} }
def flinkJoinTypeToRelType(joinType: JoinType) = joinType match {
case JoinType.INNER => JoinRelType.INNER
case JoinType.LEFT_OUTER => JoinRelType.LEFT
case JoinType.RIGHT_OUTER => JoinRelType.RIGHT
case JoinType.FULL_OUTER => JoinRelType.FULL
}
} }
...@@ -245,11 +245,12 @@ class JoinITCase( ...@@ -245,11 +245,12 @@ class JoinITCase(
TestBaseUtils.compareResultAsText(results.asJava, expected) TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
@Test(expected = classOf[TableException]) @Test
def testFullOuterJoin(): Unit = { def testFullOuterJoin(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config) val tEnv = TableEnvironment.getTableEnvironment(env, config)
tEnv.getConfig.setNullCheck(true)
val sqlQuery = "SELECT c, g FROM Table3 FULL OUTER JOIN Table5 ON b = e" val sqlQuery = "SELECT c, g FROM Table3 FULL OUTER JOIN Table5 ON b = e"
...@@ -258,16 +259,23 @@ class JoinITCase( ...@@ -258,16 +259,23 @@ class JoinITCase(
tEnv.registerTable("Table3", ds1) tEnv.registerTable("Table3", ds1)
tEnv.registerTable("Table5", ds2) tEnv.registerTable("Table5", ds2)
tEnv.sql(sqlQuery).toDataSet[Row].collect() val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n" +
"null,Hallo Welt wie\n" + "null,Hallo Welt wie gehts?\n" + "null,ABC\n" + "null,BCD\n" +
"null,CDE\n" + "null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "null,HIJ\n" +
"null,IJK\n" + "null,JKL\n" + "null,KLM"
val results = tEnv.sql(sqlQuery).toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
@Test(expected = classOf[TableException]) @Test
def testLeftOuterJoin(): Unit = { def testLeftOuterJoin(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config) val tEnv = TableEnvironment.getTableEnvironment(env, config)
tEnv.getConfig.setNullCheck(true)
val sqlQuery = "SELECT c, g FROM Table3 LEFT OUTER JOIN Table5 ON b = e" val sqlQuery = "SELECT c, g FROM Table5 LEFT OUTER JOIN Table3 ON b = e"
val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c)
val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('d, 'e, 'f, 'g, 'h) val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('d, 'e, 'f, 'g, 'h)
...@@ -275,13 +283,21 @@ class JoinITCase( ...@@ -275,13 +283,21 @@ class JoinITCase(
tEnv.registerTable("Table5", ds2) tEnv.registerTable("Table5", ds2)
tEnv.sql(sqlQuery).toDataSet[Row].collect() tEnv.sql(sqlQuery).toDataSet[Row].collect()
val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n" +
"null,Hallo Welt wie\n" + "null,Hallo Welt wie gehts?\n" + "null,ABC\n" + "null,BCD\n" +
"null,CDE\n" + "null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "null,HIJ\n" +
"null,IJK\n" + "null,JKL\n" + "null,KLM"
val results = tEnv.sql(sqlQuery).toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
@Test(expected = classOf[TableException]) @Test
def testRightOuterJoin(): Unit = { def testRightOuterJoin(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config) val tEnv = TableEnvironment.getTableEnvironment(env, config)
tEnv.getConfig.setNullCheck(true)
val sqlQuery = "SELECT c, g FROM Table3 RIGHT OUTER JOIN Table5 ON b = e" val sqlQuery = "SELECT c, g FROM Table3 RIGHT OUTER JOIN Table5 ON b = e"
...@@ -291,5 +307,12 @@ class JoinITCase( ...@@ -291,5 +307,12 @@ class JoinITCase(
tEnv.registerTable("Table5", ds2) tEnv.registerTable("Table5", ds2)
tEnv.sql(sqlQuery).toDataSet[Row].collect() tEnv.sql(sqlQuery).toDataSet[Row].collect()
val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n" +
"null,Hallo Welt wie\n" + "null,Hallo Welt wie gehts?\n" + "null,ABC\n" + "null,BCD\n" +
"null,CDE\n" + "null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "null,HIJ\n" +
"null,IJK\n" + "null,JKL\n" + "null,KLM"
val results = tEnv.sql(sqlQuery).toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
} }
...@@ -178,7 +178,7 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) ...@@ -178,7 +178,7 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
val joinT = ds1.join(ds2).where('a === 'd).select('g.count) val joinT = ds1.join(ds2).where('a === 'd).select('g.count)
val expected = "6" val expected = "6"
val results = joinT.toDataSet[Row]collect() val results = joinT.toDataSet[Row] collect()
TestBaseUtils.compareResultAsText(results.asJava, expected) TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
...@@ -196,7 +196,7 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) ...@@ -196,7 +196,7 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
.select('b.sum, 'g.count) .select('b.sum, 'g.count)
val expected = "6,3\n" + "4,2\n" + "1,1" val expected = "6,3\n" + "4,2\n" + "1,1"
val results = joinT.toDataSet[Row]collect() val results = joinT.toDataSet[Row] collect()
TestBaseUtils.compareResultAsText(results.asJava, expected) TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
...@@ -216,7 +216,7 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) ...@@ -216,7 +216,7 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
.select('a, 'f, 'l) .select('a, 'f, 'l)
val expected = "2,1,Hello\n" + "2,1,Hello world\n" + "1,0,Hi" val expected = "2,1,Hello\n" + "2,1,Hello world\n" + "1,0,Hi"
val results = joinT.toDataSet[Row]collect() val results = joinT.toDataSet[Row] collect()
TestBaseUtils.compareResultAsText(results.asJava, expected) TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
...@@ -228,13 +228,12 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) ...@@ -228,13 +228,12 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h) val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
val joinT = ds1.join(ds2).filter('a === 'd && ('b === 'e || 'b === 'e - 10 )).select('c, 'g) val joinT = ds1.join(ds2).filter('a === 'd && ('b === 'e || 'b === 'e - 10)).select('c, 'g)
val expected = val expected = "Hi,Hallo\n" +
"Hi,Hallo\n" + "Hello,Hallo Welt\n" +
"Hello,Hallo Welt\n" + "I am fine.,IJK"
"I am fine.,IJK" val results = joinT.toDataSet[Row] collect()
val results = joinT.toDataSet[Row]collect()
TestBaseUtils.compareResultAsText(results.asJava, expected) TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
...@@ -248,13 +247,12 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) ...@@ -248,13 +247,12 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
val joinT = ds1.join(ds2).filter('b === 'h + 1 && 'a - 1 === 'd + 2).select('c, 'g) val joinT = ds1.join(ds2).filter('b === 'h + 1 && 'a - 1 === 'd + 2).select('c, 'g)
val expected = val expected = "I am fine.,Hallo Welt\n" +
"I am fine.,Hallo Welt\n" + "Luke Skywalker,Hallo Welt wie gehts?\n" +
"Luke Skywalker,Hallo Welt wie gehts?\n" + "Luke Skywalker,ABC\n" +
"Luke Skywalker,ABC\n" + "Comment#2,HIJ\n" +
"Comment#2,HIJ\n" + "Comment#2,IJK"
"Comment#2,IJK" val results = joinT.toDataSet[Row] collect()
val results = joinT.toDataSet[Row]collect()
TestBaseUtils.compareResultAsText(results.asJava, expected) TestBaseUtils.compareResultAsText(results.asJava, expected)
} }
...@@ -271,4 +269,109 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) ...@@ -271,4 +269,109 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
ds1.join(ds2).where('b === 'e).select('c, 'g) ds1.join(ds2).where('b === 'e).select('c, 'g)
} }
@Test
def testLeftJoinWithMultipleKeys(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
tEnv.getConfig.setNullCheck(true)
val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
val joinT = ds1.leftOuterJoin(ds2, 'a === 'd && 'b === 'h).select('c, 'g)
val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt wie gehts?\n" +
"Hello world,ABC\n" + "Hello world, how are you?,null\n" + "I am fine.,HIJ\n" +
"I am fine.,IJK\n" + "Luke Skywalker,null\n" + "Comment#1,null\n" + "Comment#2,null\n" +
"Comment#3,null\n" + "Comment#4,null\n" + "Comment#5,null\n" + "Comment#6,null\n" +
"Comment#7,null\n" + "Comment#8,null\n" + "Comment#9,null\n" + "Comment#10,null\n" +
"Comment#11,null\n" + "Comment#12,null\n" + "Comment#13,null\n" + "Comment#14,null\n" +
"Comment#15,null\n"
val results = joinT.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test(expected = classOf[ValidationException])
def testNoJoinCondition(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
tEnv.getConfig.setNullCheck(true)
val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
val joinT = ds2.leftOuterJoin(ds1, 'b === 'd && 'b < 3).select('c, 'g)
}
@Test(expected = classOf[ValidationException])
def testNoEquiJoin(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
tEnv.getConfig.setNullCheck(true)
val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
val joinT = ds2.leftOuterJoin(ds1, 'b < 'd).select('c, 'g)
}
@Test
def testRightJoinWithMultipleKeys(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
tEnv.getConfig.setNullCheck(true)
val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
val joinT = ds1.rightOuterJoin(ds2, "a = d && b = h").select('c, 'g)
val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "null,Hallo Welt wie\n" +
"Hello world,Hallo Welt wie gehts?\n" + "Hello world,ABC\n" + "null,BCD\n" + "null,CDE\n" +
"null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "I am fine.,HIJ\n" +
"I am fine.,IJK\n" + "null,JKL\n" + "null,KLM\n"
val results = joinT.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test
def testRightJoinWithNotOnlyEquiJoin(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
tEnv.getConfig.setNullCheck(true)
val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
val joinT = ds1.rightOuterJoin(ds2, "a = d && b < h").select('c, 'g)
val expected = "Hello world,BCD\n"
val results = joinT.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
@Test
def testFullOuterJoinWithMultipleKeys(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
tEnv.getConfig.setNullCheck(true)
val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
val joinT = ds1.fullOuterJoin(ds2, 'a === 'd && 'b === 'h).select('c, 'g)
val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "null,Hallo Welt wie\n" +
"Hello world,Hallo Welt wie gehts?\n" + "Hello world,ABC\n" + "null,BCD\n" + "null,CDE\n" +
"null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "I am fine.,HIJ\n" +
"I am fine.,IJK\n" + "null,JKL\n" + "null,KLM\n" + "Luke Skywalker,null\n" +
"Comment#1,null\n" + "Comment#2,null\n" + "Comment#3,null\n" + "Comment#4,null\n" +
"Comment#5,null\n" + "Comment#6,null\n" + "Comment#7,null\n" + "Comment#8,null\n" +
"Comment#9,null\n" + "Comment#10,null\n" + "Comment#11,null\n" + "Comment#12,null\n" +
"Comment#13,null\n" + "Comment#14,null\n" + "Comment#15,null\n" +
"Hello world, how are you?,null\n"
val results = joinT.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册