提交 b80d89a6 编写于 作者: C chengxiang li 提交者: twalthr

[FLINK-2951] [Table API] add union operator to Table API.

This closes #1315.
上级 0ca425a6
......@@ -252,6 +252,11 @@ class JavaBatchTranslator extends PlanTranslator {
val inType = translatedInput.getType.asInstanceOf[CompositeType[Row]]
val filter = new ExpressionFilterFunction[Row](predicate, inType)
translatedInput.filter(filter).name(predicate.toString)
// UnionAll: translate both inputs recursively and union the translated results.
// The resulting operator is named "Union: " followed by the plan node's string form.
case uni@UnionAll(left, right) =>
val translatedLeft = translateInternal(left)
val translatedRight = translateInternal(right)
translatedLeft.union(translatedRight).name("Union: " + uni)
}
}
......
......@@ -197,6 +197,11 @@ class JavaStreamingTranslator extends PlanTranslator {
val inType = translatedInput.getType.asInstanceOf[CompositeType[Row]]
val filter = new ExpressionFilterFunction[Row](predicate, inType)
translatedInput.filter(filter)
// UnionAll: translate both inputs recursively and union the translated results.
case UnionAll(left, right) =>
val translatedLeft = translateInternal(left)
val translatedRight = translateInternal(right)
translatedLeft.union(translatedRight)
}
}
......
......@@ -243,5 +243,29 @@ case class Table(private[flink] val operation: PlanNode) {
this.copy(operation = Join(operation, right.operation))
}
/**
 * Unions two [[Table]]s. Similar to an SQL UNION ALL. The output fields of both tables
 * must be equal; this is checked when the operation is created.
 *
 * Example:
 *
 * {{{
 *   left.unionAll(right)
 * }}}
 *
 * @param right the table to union with this table
 * @return a new [[Table]] whose operation is the `UnionAll` of this table's operation and
 *         `right`'s operation
 * @throws ExpressionException if the output fields of the two tables are not equal
 */
def unionAll(right: Table): Table = {
  val leftFields = operation.outputFields
  val rightFields = right.operation.outputFields

  // Fail eagerly, while the plan is being built, instead of deferring the mismatch to
  // the translation of the plan.
  if (leftFields != rightFields) {
    throw new ExpressionException(
      "The fields of union inputs should be fully overlapped, left input fields: " +
        leftFields.mkString(", ") +
        " and right input fields: " +
        rightFields.mkString(", ")
    )
  }

  this.copy(operation = UnionAll(operation, right.operation))
}
override def toString: String = s"Expression($operation)"
}
......@@ -120,3 +120,14 @@ case class Aggregate(
override def toString = s"Aggregate($input, ${aggregations.mkString(",")})"
}
/**
 * Plan node for a UNION ALL of two inputs: all elements of `left` and `right` are kept.
 * The node exposes the output fields of its left input as its own output fields.
 */
case class UnionAll(left: PlanNode, right: PlanNode) extends PlanNode {

  def outputFields = left.outputFields

  val children = Seq(left, right)

  override def toString: String = "Union(" + left + ", " + right + ")"
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.table.test;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.table.TableEnvironment;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.api.table.ExpressionException;
import org.apache.flink.api.table.Row;
import org.apache.flink.api.table.Table;
import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.List;
/**
 * Integration tests for the {@code unionAll} operator of the Java Table API.
 */
@RunWith(Parameterized.class)
public class UnionITCase extends MultipleProgramsTestBase {

    public UnionITCase(TestExecutionMode mode) {
        super(mode);
    }

    /** Union of two tables with identical fields keeps the rows of both inputs. */
    @Test
    public void testUnion() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        TableEnvironment tableEnv = new TableEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
        DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env);

        Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
        Table in2 = tableEnv.fromDataSet(ds2, "a, b, c");

        Table selected = in1.unionAll(in2).select("c");
        DataSet<Row> ds = tableEnv.toDataSet(selected, Row.class);
        List<Row> results = ds.collect();

        String expected = "Hi\n" + "Hello\n" + "Hello world\n" + "Hi\n" + "Hello\n" + "Hello world\n";
        compareResultAsText(results, expected);
    }

    /** A filter applied on top of a union is evaluated against the rows of both inputs. */
    @Test
    public void testUnionWithFilter() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        TableEnvironment tableEnv = new TableEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
        DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds2 = CollectionDataSets.get5TupleDataSet(env);

        Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
        // Project the 5-tuple down to the same three fields as in1.
        Table in2 = tableEnv.fromDataSet(ds2, "a, b, d, c, e").select("a, b, c");

        Table selected = in1.unionAll(in2).where("b < 2").select("c");
        DataSet<Row> ds = tableEnv.toDataSet(selected, Row.class);
        List<Row> results = ds.collect();

        String expected = "Hi\n" + "Hallo\n";
        compareResultAsText(results, expected);
    }

    /** Inputs with different field names (and a different number of fields) are rejected. */
    @Test(expected = ExpressionException.class)
    public void testUnionFieldsNameNotOverlap1() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        TableEnvironment tableEnv = new TableEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
        DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds2 = CollectionDataSets.get5TupleDataSet(env);

        Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
        Table in2 = tableEnv.fromDataSet(ds2, "d, e, f, g, h");

        // unionAll validates its inputs when it is called, so the expected exception must be
        // raised here; no translation or execution of the plan is needed.
        in1.unionAll(in2);
    }

    /**
     * Inputs with the same field names but different field types are rejected: here "c" is
     * the String field of the 3-tuple on the left and the third (Integer) field of the
     * 5-tuple on the right.
     */
    @Test(expected = ExpressionException.class)
    public void testUnionFieldsNameNotOverlap2() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        TableEnvironment tableEnv = new TableEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
        DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds2 = CollectionDataSets.get5TupleDataSet(env);

        Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
        Table in2 = tableEnv.fromDataSet(ds2, "a, b, c, d, e").select("a, b, c");

        // unionAll validates its inputs when it is called, so the expected exception must be
        // raised here; no translation or execution of the plan is needed.
        in1.unionAll(in2);
    }

    /** An aggregation on top of a union counts the rows of both inputs. */
    @Test
    public void testUnionWithAggregation() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        TableEnvironment tableEnv = new TableEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
        DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds2 = CollectionDataSets.get5TupleDataSet(env);

        Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
        // Project the 5-tuple down to the same three fields as in1.
        Table in2 = tableEnv.fromDataSet(ds2, "a, b, d, c, e").select("a, b, c");

        Table selected = in1.unionAll(in2).select("c.count");
        DataSet<Row> ds = tableEnv.toDataSet(selected, Row.class);
        List<Row> results = ds.collect();

        String expected = "18";
        compareResultAsText(results, expected);
    }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.table.test
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.table._
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.api.table.{ExpressionException, Row}
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils}
import org.junit._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import scala.collection.JavaConversions
/**
 * Integration tests for the `unionAll` operator of the Scala Table API.
 */
@RunWith(classOf[Parameterized])
class UnionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) {

  /** Union of two tables with identical fields keeps the rows of both inputs. */
  @Test
  def testUnion(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
    val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)

    val unionDs = ds1.unionAll(ds2).select('c)

    val results = unionDs.toDataSet[Row].collect()
    val expected = "Hi\n" + "Hello\n" + "Hello world\n" + "Hi\n" + "Hello\n" + "Hello world\n"
    TestBaseUtils.compareResultAsText(JavaConversions.seqAsJavaList(results), expected)
  }

  /** A filter applied on top of a union is evaluated against the rows of both inputs. */
  @Test
  def testUnionWithFilter(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
    val ds2 = CollectionDataSets.get5TupleDataSet(env).as('a, 'b, 'd, 'c, 'e)

    // Project the 5-tuple down to the same three fields as ds1 before the union.
    val unionDs = ds1.unionAll(ds2.select('a, 'b, 'c)).filter('b < 2).select('c)

    val results = unionDs.toDataSet[Row].collect()
    val expected = "Hi\n" + "Hallo\n"
    TestBaseUtils.compareResultAsText(JavaConversions.seqAsJavaList(results), expected)
  }

  /** Inputs with a different set of fields are rejected. */
  @Test(expected = classOf[ExpressionException])
  def testUnionFieldsNameNotOverlap1(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
    val ds2 = CollectionDataSets.get5TupleDataSet(env).as('a, 'b, 'd, 'c, 'e)

    // unionAll validates its inputs when it is called, so the expected exception must be
    // raised here; no translation or execution of the plan is needed.
    ds1.unionAll(ds2)
  }

  /**
   * Inputs with the same field names but different field types are rejected: 'c names the
   * third field of the 3-tuple on the left and the third field of the 5-tuple on the right.
   */
  @Test(expected = classOf[ExpressionException])
  def testUnionFieldsNameNotOverlap2(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
    val ds2 = CollectionDataSets.get5TupleDataSet(env).as('a, 'b, 'c, 'd, 'e).select('a, 'b, 'c)

    // unionAll validates its inputs when it is called, so the expected exception must be
    // raised here; no translation or execution of the plan is needed.
    ds1.unionAll(ds2)
  }

  /** An aggregation on top of a union counts the rows of both inputs. */
  @Test
  def testUnionWithAggregation(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
    val ds2 = CollectionDataSets.get5TupleDataSet(env).as('a, 'b, 'd, 'c, 'e)

    // Project the 5-tuple down to the same three fields as ds1 before the union.
    val unionDs = ds1.unionAll(ds2.select('a, 'b, 'c)).select('c.count)

    val results = unionDs.toDataSet[Row].collect()
    val expected = "18"
    TestBaseUtils.compareResultAsText(JavaConversions.seqAsJavaList(results), expected)
  }
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册