From 45e680c2b6c9c2f64ce55423b755a13d402ff8ba Mon Sep 17 00:00:00 2001 From: Chiwan Park Date: Mon, 6 Apr 2015 05:07:11 +0900 Subject: [PATCH] [FLINK-703] [scala api] Use complete element as join key This closes #572 --- .../apache/flink/api/java/operators/Keys.java | 4 +- .../org/apache/flink/api/scala/DataSet.scala | 1 - .../scala/unfinishedKeyPairOperation.scala | 5 -- .../api/scala/operators/CoGroupITCase.scala | 42 +++++++++++ .../scala/operators/CoGroupOperatorTest.scala | 57 ++++++++++++++ .../scala/operators/GroupReduceITCase.scala | 9 +++ .../api/scala/operators/GroupingTest.scala | 38 +++++++++- .../api/scala/operators/JoinITCase.scala | 22 ++++++ .../scala/operators/JoinOperatorTest.scala | 74 +++++++++++++++++-- 9 files changed, 236 insertions(+), 16 deletions(-) diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java index ee233e8e48f..69d306f8912 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java @@ -280,7 +280,7 @@ public abstract class Keys { if (!type.isKeyType()) { throw new InvalidProgramException("This type (" + type + ") cannot be used as key."); } else if (expressionsIn.length != 1 || !(Keys.ExpressionKeys.SELECT_ALL_CHAR.equals(expressionsIn[0]) || Keys.ExpressionKeys.SELECT_ALL_CHAR_SCALA.equals(expressionsIn[0]))) { - throw new IllegalArgumentException("Field expression for atomic type must be equal to '*' or '_'."); + throw new InvalidProgramException("Field expression for atomic type must be equal to '*' or '_'."); } keyFields = new ArrayList(1); @@ -297,7 +297,7 @@ public abstract class Keys { for (int i = 0; i < expressions.length; i++) { List keys = cType.getFlatFields(expressions[i]); // use separate list to do a size check if(keys.size() == 0) { - throw new IllegalArgumentException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType); + throw new InvalidProgramException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType); } keyFields.addAll(keys); } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala index 3b80a23f604..56762294725 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala @@ -807,7 +807,6 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { * This will not create a new DataSet, it will just attach the field names which will be * used for grouping when executing a grouped operation. * - * This only works on CaseClass DataSets. */ def groupBy(firstField: String, otherFields: String*): GroupedDataSet[T] = { new GroupedDataSet[T]( diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala index de03687595d..08d02424177 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala @@ -65,8 +65,6 @@ private[flink] abstract class UnfinishedKeyPairOperation[L, R, O]( * a [[HalfUnfinishedKeyPairOperation]] on which `equalTo` must be called to specify the * key for the right side. The result after specifying the right side key is the finished * operation. - * - * This only works on a CaseClass [[DataSet]]. */ def where(firstLeftField: String, otherLeftFields: String*) = { val leftKey = new ExpressionKeys[L]( @@ -113,8 +111,6 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O]( /** * Specify the key fields for the right side of the key based operation. This returns * the finished operation. - * - * This only works on a CaseClass [[DataSet]]. */ def equalTo(firstRightField: String, otherRightFields: String*): O = { val rightKey = new ExpressionKeys[R]( @@ -125,7 +121,6 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O]( leftKey + " Right: " + rightKey) } unfinished.finish(leftKey, rightKey) - } /** diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/CoGroupITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/CoGroupITCase.scala index 42ec42be4ae..3379fe2d0b8 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/CoGroupITCase.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/CoGroupITCase.scala @@ -383,5 +383,47 @@ class CoGroupITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mo env.execute() expectedResult = "-1,20000,Flink\n" + "-1,10000,Flink\n" + "-1,30000,Flink\n" } + + @Test + def testCoGroupWithAtomic1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env) + val ds2 = env.fromElements(0, 1, 2) + val coGroupDs = ds1.coGroup(ds2).where(0).equalTo("*") { + (first, second, out: Collector[(Int, Long, String)]) => + for (p <- first) { + for (t <- second) { + if (p._1 == t) { + out.collect(p) + } + } + } + } + + coGroupDs.writeAsText(resultPath, writeMode = WriteMode.OVERWRITE) + env.execute() + expectedResult = "(1,1,Hi)\n(2,2,Hello)" + } + + @Test + def testCoGroupWithAtomic2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromElements(0, 1, 2) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env) + val coGroupDs = ds1.coGroup(ds2).where("*").equalTo(0) { + (first, second, out: Collector[(Int, Long, String)]) => + for (p <- first) { + for (t <- second) { + if (p == t._1) { + out.collect(t) + } + } + } + } + + coGroupDs.writeAsText(resultPath, writeMode = WriteMode.OVERWRITE) + env.execute() + expectedResult = "(1,1,Hi)\n(2,2,Hello)" + } } diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala index 115ca35ce10..4f052dd11de 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala @@ -17,6 +17,9 @@ */ package org.apache.flink.api.scala.operators +import java.util + +import org.apache.flink.api.common.InvalidProgramException import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException import org.junit.Assert import org.junit.Test @@ -268,6 +271,60 @@ class CoGroupOperatorTest { // Should not work, more than one field position key ds1.coGroup(ds2).where(1, 3).equalTo { _.myLong } } + + @Test + def testCoGroupWithAtomic1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromCollection(emptyTupleData) + val ds2 = env.fromElements(0, 1, 2) + + ds1.coGroup(ds2).where(0).equalTo("*") + } + + @Test + def testCoGroupWithAtomic2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromElements(0, 1, 2) + val ds2 = env.fromCollection(emptyTupleData) + + ds1.coGroup(ds2).where("*").equalTo(0) + } + + @Test(expected = classOf[InvalidProgramException]) + def testCoGroupWithInvalidAtomic1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromElements(0, 1, 2) + val ds2 = env.fromCollection(emptyTupleData) + + ds1.coGroup(ds2).where("invalidKey") + } + + @Test(expected = classOf[InvalidProgramException]) + def testCoGroupWithInvalidAtomic2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromCollection(emptyTupleData) + val ds2 = env.fromElements(0, 1, 2) + + ds1.coGroup(ds2).where(0).equalTo("invalidKey") + } + + @Test(expected = classOf[InvalidProgramException]) + def testCoGroupWithInvalidAtomic3(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromElements(new util.ArrayList[Integer]()) + val ds2 = env.fromElements(0, 0, 0) + + ds1.coGroup(ds2).where("*") + } + + @Test(expected = classOf[InvalidProgramException]) + def testCoGroupWithInvalidAtomic4(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromElements(0, 0, 0) + val ds2 = env.fromElements(new util.ArrayList[Integer]()) + + ds1.coGroup(ds2).where("*").equalTo("*") + } } diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala index b6f045b8d57..fe9e3f34505 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala @@ -743,6 +743,15 @@ class GroupReduceITCase(mode: TestExecutionMode) extends MultipleProgramsTestBas expected = "b\nccc\nee\n" } + @Test + def testWithAtomic1: Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements(0, 1, 1, 2) + val reduceDs = ds.groupBy("*").reduceGroup((ints: Iterator[Int]) => ints.next()) + reduceDs.writeAsText(resultPath, WriteMode.OVERWRITE) + env.execute() + expected = "0\n1\n2" + } } @RichGroupReduceFunction.Combinable diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala index a816ae869b5..33309295860 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala @@ -17,6 +17,8 @@ */ package org.apache.flink.api.scala.operators +import java.util + import org.apache.flink.api.scala.util.CollectionDataSets.CustomType import org.junit.Assert import org.apache.flink.api.common.InvalidProgramException @@ -96,7 +98,7 @@ class GroupingTest { } } - @Test(expected = classOf[IllegalArgumentException]) + @Test(expected = classOf[InvalidProgramException]) def testGroupByKeyFields2(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val longDs = env.fromCollection(emptyLongData) @@ -146,7 +148,7 @@ class GroupingTest { } } - @Test(expected = classOf[IllegalArgumentException]) + @Test(expected = classOf[InvalidProgramException]) def testGroupByKeyExpressions2(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment @@ -224,5 +226,37 @@ class GroupingTest { case e: Exception => Assert.fail() } } + + @Test + def testAtomicValue1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements(0, 1, 2) + + ds.groupBy("*") + } + + @Test(expected = classOf[InvalidProgramException]) + def testAtomicValueInvalid1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements(0, 1, 2) + + ds.groupBy("invalidKey") + } + + @Test(expected = classOf[InvalidProgramException]) + def testAtomicValueInvalid2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements(0, 1, 2) + + ds.groupBy("_", "invalidKey") + } + + @Test(expected = classOf[InvalidProgramException]) + def testAtomicValueInvalid3(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements(new util.ArrayList[Integer]()) + + ds.groupBy("*") + } } diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala index 9bf3ccea754..4135ab29ada 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala @@ -384,4 +384,26 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) "2 Second (20,200,2000,Two) 20000,(20000,20,200,2000,Two,2,Second)\n" + "3 Third (30,300,3000,Three) 30000,(30000,30,300,3000,Three,3,Third)\n" } + + @Test + def testWithAtomic1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env) + val ds2 = env.fromElements(0, 1, 2) + val joinDs = ds1.join(ds2).where(0).equalTo("*") + joinDs.writeAsCsv(resultPath, writeMode = WriteMode.OVERWRITE) + env.execute() + expected = "(1,1,Hi),1\n(2,2,Hello),2" + } + + @Test + def testWithAtomic2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromElements(0, 1, 2) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env) + val joinDs = ds1.join(ds2).where("*").equalTo(0) + joinDs.writeAsCsv(resultPath, writeMode = WriteMode.OVERWRITE) + env.execute() + expected = "1,(1,1,Hi)\n2,(2,2,Hello)" + } } diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala index 017d78df8b9..8f36c609d38 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala @@ -17,14 +17,13 @@ */ package org.apache.flink.api.scala.operators -import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException -import org.apache.flink.api.scala.util.CollectionDataSets.CustomType -import org.junit.Assert -import org.apache.flink.api.common.InvalidProgramException -import org.junit.Ignore -import org.junit.Test +import java.util +import org.apache.flink.api.common.InvalidProgramException +import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets.CustomType +import org.junit.{Assert, Test} class JoinOperatorTest { @@ -272,5 +271,68 @@ class JoinOperatorTest { // should not work, more than one field position key ds1.join(ds2).where(1, 3) equalTo { _.myLong } } + + @Test + def testJoinWithAtomic(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromCollection(emptyTupleData) + val ds2 = env.fromCollection(emptyLongData) + + ds1.join(ds2).where(1).equalTo("*") + } + + @Test(expected = classOf[InvalidProgramException]) + def testJoinWithInvalidAtomic1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromCollection(emptyTupleData) + val ds2 = env.fromCollection(emptyLongData) + + ds1.join(ds2).where(1).equalTo("invalidKey") + } + + @Test(expected = classOf[InvalidProgramException]) + def testJoinWithInvalidAtomic2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromCollection(emptyLongData) + val ds2 = env.fromCollection(emptyTupleData) + + ds1.join(ds2).where("invalidKey").equalTo(1) + } + + @Test(expected = classOf[InvalidProgramException]) + def testJoinWithInvalidAtomic3(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromCollection(emptyTupleData) + val ds2 = env.fromCollection(emptyLongData) + + ds1.join(ds2).where(1).equalTo("_", "invalidKey") + } + + @Test(expected = classOf[InvalidProgramException]) + def testJoinWithInvalidAtomic4(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromCollection(emptyLongData) + val ds2 = env.fromCollection(emptyTupleData) + + ds1.join(ds2).where("_", "invalidKey").equalTo(1) + } + + @Test(expected = classOf[InvalidProgramException]) + def testJoinWithInvalidAtomic5(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromElements(new util.ArrayList[Integer]()) + val ds2 = env.fromCollection(emptyLongData) + + ds1.join(ds2).where("*") + } + + @Test(expected = classOf[InvalidProgramException]) + def testJoinWithInvalidAtomic6(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds1 = env.fromCollection(emptyLongData) + val ds2 = env.fromElements(new util.ArrayList[Integer]()) + + ds1.join(ds2).where("*").equalTo("*") + } } -- GitLab