提交 45e680c2 编写于 作者: C Chiwan Park 提交者: Fabian Hueske

[FLINK-703] [scala api] Use complete element as join key

This closes #572
上级 30a74c76
......@@ -280,7 +280,7 @@ public abstract class Keys<T> {
if (!type.isKeyType()) {
throw new InvalidProgramException("This type (" + type + ") cannot be used as key.");
} else if (expressionsIn.length != 1 || !(Keys.ExpressionKeys.SELECT_ALL_CHAR.equals(expressionsIn[0]) || Keys.ExpressionKeys.SELECT_ALL_CHAR_SCALA.equals(expressionsIn[0]))) {
throw new IllegalArgumentException("Field expression for atomic type must be equal to '*' or '_'.");
throw new InvalidProgramException("Field expression for atomic type must be equal to '*' or '_'.");
}
keyFields = new ArrayList<FlatFieldDescriptor>(1);
......@@ -297,7 +297,7 @@ public abstract class Keys<T> {
for (int i = 0; i < expressions.length; i++) {
List<FlatFieldDescriptor> keys = cType.getFlatFields(expressions[i]); // use separate list to do a size check
if(keys.size() == 0) {
throw new IllegalArgumentException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType);
throw new InvalidProgramException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType);
}
keyFields.addAll(keys);
}
......
......@@ -807,7 +807,6 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
* This will not create a new DataSet, it will just attach the field names which will be
* used for grouping when executing a grouped operation.
*
* This only works on CaseClass DataSets.
*/
def groupBy(firstField: String, otherFields: String*): GroupedDataSet[T] = {
new GroupedDataSet[T](
......
......@@ -65,8 +65,6 @@ private[flink] abstract class UnfinishedKeyPairOperation[L, R, O](
* a [[HalfUnfinishedKeyPairOperation]] on which `equalTo` must be called to specify the
* key for the right side. The result after specifying the right side key is the finished
* operation.
*
* This only works on a CaseClass [[DataSet]].
*/
def where(firstLeftField: String, otherLeftFields: String*) = {
val leftKey = new ExpressionKeys[L](
......@@ -113,8 +111,6 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O](
/**
* Specify the key fields for the right side of the key based operation. This returns
* the finished operation.
*
* This only works on a CaseClass [[DataSet]].
*/
def equalTo(firstRightField: String, otherRightFields: String*): O = {
val rightKey = new ExpressionKeys[R](
......@@ -125,7 +121,6 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O](
leftKey + " Right: " + rightKey)
}
unfinished.finish(leftKey, rightKey)
}
/**
......
......@@ -383,5 +383,47 @@ class CoGroupITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mo
env.execute()
expectedResult = "-1,20000,Flink\n" + "-1,10000,Flink\n" + "-1,30000,Flink\n"
}
@Test
def testCoGroupWithAtomic1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = CollectionDataSets.getSmall3TupleDataSet(env)
val ds2 = env.fromElements(0, 1, 2)
val coGroupDs = ds1.coGroup(ds2).where(0).equalTo("*") {
(first, second, out: Collector[(Int, Long, String)]) =>
for (p <- first) {
for (t <- second) {
if (p._1 == t) {
out.collect(p)
}
}
}
}
coGroupDs.writeAsText(resultPath, writeMode = WriteMode.OVERWRITE)
env.execute()
expectedResult = "(1,1,Hi)\n(2,2,Hello)"
}
@Test
def testCoGroupWithAtomic2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(0, 1, 2)
val ds2 = CollectionDataSets.getSmall3TupleDataSet(env)
val coGroupDs = ds1.coGroup(ds2).where("*").equalTo(0) {
(first, second, out: Collector[(Int, Long, String)]) =>
for (p <- first) {
for (t <- second) {
if (p == t._1) {
out.collect(t)
}
}
}
}
coGroupDs.writeAsText(resultPath, writeMode = WriteMode.OVERWRITE)
env.execute()
expectedResult = "(1,1,Hi)\n(2,2,Hello)"
}
}
......@@ -17,6 +17,9 @@
*/
package org.apache.flink.api.scala.operators
import java.util
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException
import org.junit.Assert
import org.junit.Test
......@@ -268,6 +271,60 @@ class CoGroupOperatorTest {
// Should not work, more than one field position key
ds1.coGroup(ds2).where(1, 3).equalTo { _.myLong }
}
@Test
def testCoGroupWithAtomic1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromElements(0, 1, 2)
ds1.coGroup(ds2).where(0).equalTo("*")
}
@Test
def testCoGroupWithAtomic2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(0, 1, 2)
val ds2 = env.fromCollection(emptyTupleData)
ds1.coGroup(ds2).where("*").equalTo(0)
}
@Test(expected = classOf[InvalidProgramException])
def testCoGroupWithInvalidAtomic1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(0, 1, 2)
val ds2 = env.fromCollection(emptyTupleData)
ds1.coGroup(ds2).where("invalidKey")
}
@Test(expected = classOf[InvalidProgramException])
def testCoGroupWithInvalidAtomic2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromElements(0, 1, 2)
ds1.coGroup(ds2).where(0).equalTo("invalidKey")
}
@Test(expected = classOf[InvalidProgramException])
def testCoGroupWithInvalidAtomic3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(new util.ArrayList[Integer]())
val ds2 = env.fromElements(0, 0, 0)
ds1.coGroup(ds2).where("*")
}
@Test(expected = classOf[InvalidProgramException])
def testCoGroupWithInvalidAtomic4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(0, 0, 0)
val ds2 = env.fromElements(new util.ArrayList[Integer]())
ds1.coGroup(ds2).where("*").equalTo("*")
}
}
......@@ -743,6 +743,15 @@ class GroupReduceITCase(mode: TestExecutionMode) extends MultipleProgramsTestBas
expected = "b\nccc\nee\n"
}
@Test
def testWithAtomic1: Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(0, 1, 1, 2)
val reduceDs = ds.groupBy("*").reduceGroup((ints: Iterator[Int]) => ints.next())
reduceDs.writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
expected = "0\n1\n2"
}
}
@RichGroupReduceFunction.Combinable
......
......@@ -17,6 +17,8 @@
*/
package org.apache.flink.api.scala.operators
import java.util
import org.apache.flink.api.scala.util.CollectionDataSets.CustomType
import org.junit.Assert
import org.apache.flink.api.common.InvalidProgramException
......@@ -96,7 +98,7 @@ class GroupingTest {
}
}
@Test(expected = classOf[IllegalArgumentException])
@Test(expected = classOf[InvalidProgramException])
def testGroupByKeyFields2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val longDs = env.fromCollection(emptyLongData)
......@@ -146,7 +148,7 @@ class GroupingTest {
}
}
@Test(expected = classOf[IllegalArgumentException])
@Test(expected = classOf[InvalidProgramException])
def testGroupByKeyExpressions2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
......@@ -224,5 +226,37 @@ class GroupingTest {
case e: Exception => Assert.fail()
}
}
@Test
def testAtomicValue1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(0, 1, 2)
ds.groupBy("*")
}
@Test(expected = classOf[InvalidProgramException])
def testAtomicValueInvalid1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(0, 1, 2)
ds.groupBy("invalidKey")
}
@Test(expected = classOf[InvalidProgramException])
def testAtomicValueInvalid2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(0, 1, 2)
ds.groupBy("_", "invalidKey")
}
@Test(expected = classOf[InvalidProgramException])
def testAtomicValueInvalid3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(new util.ArrayList[Integer]())
ds.groupBy("*")
}
}
......@@ -384,4 +384,26 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
"2 Second (20,200,2000,Two) 20000,(20000,20,200,2000,Two,2,Second)\n" +
"3 Third (30,300,3000,Three) 30000,(30000,30,300,3000,Three,3,Third)\n"
}
@Test
def testWithAtomic1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = CollectionDataSets.getSmall3TupleDataSet(env)
val ds2 = env.fromElements(0, 1, 2)
val joinDs = ds1.join(ds2).where(0).equalTo("*")
joinDs.writeAsCsv(resultPath, writeMode = WriteMode.OVERWRITE)
env.execute()
expected = "(1,1,Hi),1\n(2,2,Hello),2"
}
@Test
def testWithAtomic2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(0, 1, 2)
val ds2 = CollectionDataSets.getSmall3TupleDataSet(env)
val joinDs = ds1.join(ds2).where("*").equalTo(0)
joinDs.writeAsCsv(resultPath, writeMode = WriteMode.OVERWRITE)
env.execute()
expected = "1,(1,1,Hi)\n2,(2,2,Hello)"
}
}
......@@ -17,14 +17,13 @@
*/
package org.apache.flink.api.scala.operators
import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException
import org.apache.flink.api.scala.util.CollectionDataSets.CustomType
import org.junit.Assert
import org.apache.flink.api.common.InvalidProgramException
import org.junit.Ignore
import org.junit.Test
import java.util
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets.CustomType
import org.junit.{Assert, Test}
class JoinOperatorTest {
......@@ -272,5 +271,68 @@ class JoinOperatorTest {
// should not work, more than one field position key
ds1.join(ds2).where(1, 3) equalTo { _.myLong }
}
@Test
def testJoinWithAtomic(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyLongData)
ds1.join(ds2).where(1).equalTo("*")
}
@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyLongData)
ds1.join(ds2).where(1).equalTo("invalidKey")
}
@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyLongData)
val ds2 = env.fromCollection(emptyTupleData)
ds1.join(ds2).where("invalidKey").equalTo(1)
}
@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyLongData)
ds1.join(ds2).where(1).equalTo("_", "invalidKey")
}
@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyLongData)
val ds2 = env.fromCollection(emptyTupleData)
ds1.join(ds2).where("_", "invalidKey").equalTo(1)
}
@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(new util.ArrayList[Integer]())
val ds2 = env.fromCollection(emptyLongData)
ds1.join(ds2).where("*")
}
@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic6(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyLongData)
val ds2 = env.fromElements(new util.ArrayList[Integer]())
ds1.join(ds2).where("*").equalTo("*")
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册