diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/CoGroupOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/CoGroupOperator.java index 5063af771e97d47b93d5654e42fd119b04ef67a8..9ecb9a2e2584abf58eee51fe9ad9e3a06396cb1c 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/CoGroupOperator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/CoGroupOperator.java @@ -72,6 +72,24 @@ public class CoGroupOperator extends TwoInputUdfOperator) keys1).computeLogicalKeyPositions(); + ((SolutionSetPlaceHolder) input1).checkJoinKeyFields(positions); + } else { + throw new InvalidProgramException("Currently, the solution set may only be CoGrouped with using tuple field positions."); + } + } + if (input2 instanceof SolutionSetPlaceHolder) { + if (keys2 instanceof FieldPositionKeys) { + int[] positions = ((FieldPositionKeys) keys2).computeLogicalKeyPositions(); + ((SolutionSetPlaceHolder) input2).checkJoinKeyFields(positions); + } else { + throw new InvalidProgramException("Currently, the solution set may only be CoGrouped with using tuple field positions."); + } + } + this.keys1 = keys1; this.keys2 = keys2; @@ -466,24 +484,6 @@ public class CoGroupOperator extends TwoInputUdfOperator) keys1).computeLogicalKeyPositions(); - ((SolutionSetPlaceHolder) input1).checkJoinKeyFields(positions); - } else { - throw new InvalidProgramException("Currently, the solution set may only be CoGrouped with using tuple field positions."); - } - } - if (input2 instanceof SolutionSetPlaceHolder) { - if (keys2 instanceof FieldPositionKeys) { - int[] positions = ((FieldPositionKeys) keys2).computeLogicalKeyPositions(); - ((SolutionSetPlaceHolder) input2).checkJoinKeyFields(positions); - } else { - throw new InvalidProgramException("Currently, the solution set may only be CoGrouped with using tuple field positions."); - } - } return new CoGroupOperatorWithoutFunction(keys2); } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java index 704bf1b8089321783743421ef3a5b6fe98651528..bc35c1409a147f257bd511813f64531f0127b2fc 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java @@ -121,7 +121,25 @@ public abstract class JoinOperator extends TwoInputUdfOperator) keys1).computeLogicalKeyPositions(); + ((SolutionSetPlaceHolder) input1).checkJoinKeyFields(positions); + } else { + throw new InvalidProgramException("Currently, the solution set may only be joined with using tuple field positions."); + } + } + if (input2 instanceof SolutionSetPlaceHolder) { + if (keys2 instanceof FieldPositionKeys) { + int[] positions = ((FieldPositionKeys) keys2).computeLogicalKeyPositions(); + ((SolutionSetPlaceHolder) input2).checkJoinKeyFields(positions); + } else { + throw new InvalidProgramException("Currently, the solution set may only be joined with using tuple field positions."); + } + } + this.keys1 = keys1; this.keys2 = keys2; this.joinHint = hint; @@ -872,27 +890,7 @@ public abstract class JoinOperator extends TwoInputUdfOperator) keys1).computeLogicalKeyPositions(); - ((SolutionSetPlaceHolder) input1).checkJoinKeyFields(positions); - } else { - throw new InvalidProgramException("Currently, the solution set may only be joined with using tuple field positions."); - } - } - if (input2 instanceof SolutionSetPlaceHolder) { - if (keys2 instanceof FieldPositionKeys) { - int[] positions = ((FieldPositionKeys) keys2).computeLogicalKeyPositions(); - ((SolutionSetPlaceHolder) input2).checkJoinKeyFields(positions); - } else { - throw new InvalidProgramException("Currently, the solution set may only be joined with using tuple field positions."); - } - } - - + return new DefaultJoin(input1, input2, keys1, keys2, joinHint); } } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala index 5c1d3333091893f0207eb4eb16a10bcdce9d34bb..cc3d836faea1ba0684e0efabefdba46488aa414e 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala @@ -192,8 +192,6 @@ class UnfinishedCoGroupOperation[L: ClassTag, R: ClassTag]( val coGroupOperator = new CoGroupOperator[L, R, (Array[L], Array[R])]( leftInput.javaSet, rightInput.javaSet, leftKey, rightKey, coGrouper, returnType) - DeltaIterationSanityCheck(leftInput, rightInput, leftKey, rightKey) - new CoGroupDataSet(coGroupOperator, leftInput, rightInput, leftKey, rightKey) } } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala index 3dca756425148fc4156828ab8caef1a1ee4c2b4a..d333a66807ac11a88275f5f65dc501892949e1c5 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala @@ -203,49 +203,7 @@ class UnfinishedJoinOperation[L, R]( val joinOperator = new EquiJoin[L, R, (L, R)]( leftSet.javaSet, rightSet.javaSet, leftKey, rightKey, joiner, returnType, joinHint) - DeltaIterationSanityCheck(leftSet, rightSet, leftKey, rightKey) - new JoinDataSet(joinOperator, leftSet, rightSet, leftKey, rightKey) } } - -/** - * This checks whether joining/coGrouping with the DeltaIteration SolutionSet uses the - * same key given when creating the DeltaIteration. - */ -private[flink] object DeltaIterationSanityCheck { - def apply[L, R]( - leftSet: DataSet[L], - rightSet: DataSet[R], - leftKey: Keys[L], - rightKey: Keys[R]) = { - // sanity check solution set key mismatches - leftSet.javaSet match { - case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] => - leftKey match { - case keyFields: Keys.FieldPositionKeys[_] => - val positions: Array[Int] = keyFields.computeLogicalKeyPositions - solutionSet.checkJoinKeyFields(positions) - case _ => - throw new InvalidProgramException("Currently, the solution set may only be joined " + - "with " + - "using tuple field positions.") - } - case _ => - } - rightSet.javaSet match { - case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] => - rightKey match { - case keyFields: Keys.FieldPositionKeys[_] => - val positions: Array[Int] = keyFields.computeLogicalKeyPositions - solutionSet.checkJoinKeyFields(positions) - case _ => - throw new InvalidProgramException("Currently, the solution set may only be joined " + - "with " + - "using tuple field positions.") - } - case _ => - } - } -} diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/DeltaIterationSanityCheckTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/DeltaIterationSanityCheckTest.scala index f809cb485fdd0819b3b178a44c0c2c4b798063e7..729fda4183f99070a5ed646864fece17e43f358c 100644 --- a/flink-scala/src/test/scala/org/apache/flink/api/scala/DeltaIterationSanityCheckTest.scala +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/DeltaIterationSanityCheckTest.scala @@ -21,9 +21,6 @@ package org.apache.flink.api.scala import org.junit.Test import org.apache.flink.api.common.InvalidProgramException -import org.apache.flink.api.scala._ -import org.scalatest.junit.AssertionsForJUnit - // Verify that the sanity checking in delta iterations works. We just // have a dummy job that is not meant to be executed. Only verify that // the join/coGroup inside the iteration is checked. @@ -40,7 +37,7 @@ class DeltaIterationSanityCheckTest extends Serializable { (result, ws) } - val output = iteration.print() + iteration.print() } @Test @@ -54,7 +51,7 @@ class DeltaIterationSanityCheckTest extends Serializable { (result, ws) } - val output = iteration.print() + iteration.print() } @Test(expected = classOf[InvalidProgramException]) @@ -68,7 +65,7 @@ class DeltaIterationSanityCheckTest extends Serializable { (result, ws) } - val output = iteration.print() + iteration.print() } @Test(expected = classOf[InvalidProgramException]) @@ -82,7 +79,7 @@ class DeltaIterationSanityCheckTest extends Serializable { (result, ws) } - val output = iteration.print() } + iteration.print() } @Test(expected = classOf[InvalidProgramException]) def testIncorrectJoinWithSolution3(): Unit = { @@ -95,7 +92,7 @@ class DeltaIterationSanityCheckTest extends Serializable { (result, ws) } - val output = iteration.print() + iteration.print() } @Test @@ -109,7 +106,7 @@ class DeltaIterationSanityCheckTest extends Serializable { (result, ws) } - val output = iteration.print() + iteration.print() } @Test @@ -123,7 +120,7 @@ class DeltaIterationSanityCheckTest extends Serializable { (result, ws) } - val output = iteration.print() + iteration.print() } @Test(expected = classOf[InvalidProgramException]) @@ -137,7 +134,7 @@ class DeltaIterationSanityCheckTest extends Serializable { (result, ws) } - val output = iteration.print() + iteration.print() } @Test(expected = classOf[InvalidProgramException]) @@ -151,7 +148,7 @@ class DeltaIterationSanityCheckTest extends Serializable { (result, ws) } - val output = iteration.print() } + iteration.print() } @Test(expected = classOf[InvalidProgramException]) def testIncorrectCoGroupWithSolution3(): Unit = { @@ -164,6 +161,6 @@ class DeltaIterationSanityCheckTest extends Serializable { (result, ws) } - val output = iteration.print() + iteration.print() } }