提交 b689f3fc 编写于 作者: A Aljoscha Krettek

Move DeltaIteration Sanity Check to Base Operators

Is now in JoinOperator and CoGroupOperator. We don't need
the special Scala API sanity check anymore now since they
use java operators that now correctly check in the base
class.
上级 02c08456
......@@ -72,6 +72,24 @@ public class CoGroupOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1, I2, OU
throw new NullPointerException();
}
// sanity check solution set key mismatches
if (input1 instanceof SolutionSetPlaceHolder) {
if (keys1 instanceof FieldPositionKeys) {
int[] positions = ((FieldPositionKeys<?>) keys1).computeLogicalKeyPositions();
((SolutionSetPlaceHolder<?>) input1).checkJoinKeyFields(positions);
} else {
throw new InvalidProgramException("Currently, the solution set may only be CoGrouped with using tuple field positions.");
}
}
if (input2 instanceof SolutionSetPlaceHolder) {
if (keys2 instanceof FieldPositionKeys) {
int[] positions = ((FieldPositionKeys<?>) keys2).computeLogicalKeyPositions();
((SolutionSetPlaceHolder<?>) input2).checkJoinKeyFields(positions);
} else {
throw new InvalidProgramException("Currently, the solution set may only be CoGrouped with using tuple field positions.");
}
}
this.keys1 = keys1;
this.keys2 = keys2;
......@@ -466,24 +484,6 @@ public class CoGroupOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1, I2, OU
if (!keys1.areCompatibale(keys2)) {
throw new InvalidProgramException("The pair of co-group keys are not compatible with each other.");
}
// sanity check solution set key mismatches
if (input1 instanceof SolutionSetPlaceHolder) {
if (keys1 instanceof FieldPositionKeys) {
int[] positions = ((FieldPositionKeys<?>) keys1).computeLogicalKeyPositions();
((SolutionSetPlaceHolder<?>) input1).checkJoinKeyFields(positions);
} else {
throw new InvalidProgramException("Currently, the solution set may only be CoGrouped with using tuple field positions.");
}
}
if (input2 instanceof SolutionSetPlaceHolder) {
if (keys2 instanceof FieldPositionKeys) {
int[] positions = ((FieldPositionKeys<?>) keys2).computeLogicalKeyPositions();
((SolutionSetPlaceHolder<?>) input2).checkJoinKeyFields(positions);
} else {
throw new InvalidProgramException("Currently, the solution set may only be CoGrouped with using tuple field positions.");
}
}
return new CoGroupOperatorWithoutFunction(keys2);
}
......
......@@ -121,7 +121,25 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
if (keys1 == null || keys2 == null) {
throw new NullPointerException();
}
// sanity check solution set key mismatches
if (input1 instanceof SolutionSetPlaceHolder) {
if (keys1 instanceof FieldPositionKeys) {
int[] positions = ((FieldPositionKeys<?>) keys1).computeLogicalKeyPositions();
((SolutionSetPlaceHolder<?>) input1).checkJoinKeyFields(positions);
} else {
throw new InvalidProgramException("Currently, the solution set may only be joined with using tuple field positions.");
}
}
if (input2 instanceof SolutionSetPlaceHolder) {
if (keys2 instanceof FieldPositionKeys) {
int[] positions = ((FieldPositionKeys<?>) keys2).computeLogicalKeyPositions();
((SolutionSetPlaceHolder<?>) input2).checkJoinKeyFields(positions);
} else {
throw new InvalidProgramException("Currently, the solution set may only be joined with using tuple field positions.");
}
}
this.keys1 = keys1;
this.keys2 = keys2;
this.joinHint = hint;
......@@ -872,27 +890,7 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
if (!keys1.areCompatibale(keys2)) {
throw new InvalidProgramException("The pair of join keys are not compatible with each other.");
}
// sanity check solution set key mismatches
if (input1 instanceof SolutionSetPlaceHolder) {
if (keys1 instanceof FieldPositionKeys) {
int[] positions = ((FieldPositionKeys<?>) keys1).computeLogicalKeyPositions();
((SolutionSetPlaceHolder<?>) input1).checkJoinKeyFields(positions);
} else {
throw new InvalidProgramException("Currently, the solution set may only be joined with using tuple field positions.");
}
}
if (input2 instanceof SolutionSetPlaceHolder) {
if (keys2 instanceof FieldPositionKeys) {
int[] positions = ((FieldPositionKeys<?>) keys2).computeLogicalKeyPositions();
((SolutionSetPlaceHolder<?>) input2).checkJoinKeyFields(positions);
} else {
throw new InvalidProgramException("Currently, the solution set may only be joined with using tuple field positions.");
}
}
return new DefaultJoin<I1, I2>(input1, input2, keys1, keys2, joinHint);
}
}
......
......@@ -192,8 +192,6 @@ class UnfinishedCoGroupOperation[L: ClassTag, R: ClassTag](
val coGroupOperator = new CoGroupOperator[L, R, (Array[L], Array[R])](
leftInput.javaSet, rightInput.javaSet, leftKey, rightKey, coGrouper, returnType)
DeltaIterationSanityCheck(leftInput, rightInput, leftKey, rightKey)
new CoGroupDataSet(coGroupOperator, leftInput, rightInput, leftKey, rightKey)
}
}
......@@ -203,49 +203,7 @@ class UnfinishedJoinOperation[L, R](
val joinOperator = new EquiJoin[L, R, (L, R)](
leftSet.javaSet, rightSet.javaSet, leftKey, rightKey, joiner, returnType, joinHint)
DeltaIterationSanityCheck(leftSet, rightSet, leftKey, rightKey)
new JoinDataSet(joinOperator, leftSet, rightSet, leftKey, rightKey)
}
}
/**
* This checks whether joining/coGrouping with the DeltaIteration SolutionSet uses the
* same key given when creating the DeltaIteration.
*/
private[flink] object DeltaIterationSanityCheck {
def apply[L, R](
leftSet: DataSet[L],
rightSet: DataSet[R],
leftKey: Keys[L],
rightKey: Keys[R]) = {
// sanity check solution set key mismatches
leftSet.javaSet match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
leftKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
val positions: Array[Int] = keyFields.computeLogicalKeyPositions
solutionSet.checkJoinKeyFields(positions)
case _ =>
throw new InvalidProgramException("Currently, the solution set may only be joined " +
"with " +
"using tuple field positions.")
}
case _ =>
}
rightSet.javaSet match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
rightKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
val positions: Array[Int] = keyFields.computeLogicalKeyPositions
solutionSet.checkJoinKeyFields(positions)
case _ =>
throw new InvalidProgramException("Currently, the solution set may only be joined " +
"with " +
"using tuple field positions.")
}
case _ =>
}
}
}
......@@ -21,9 +21,6 @@ package org.apache.flink.api.scala
import org.junit.Test
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.scala._
import org.scalatest.junit.AssertionsForJUnit
// Verify that the sanity checking in delta iterations works. We just
// have a dummy job that is not meant to be executed. Only verify that
// the join/coGroup inside the iteration is checked.
......@@ -40,7 +37,7 @@ class DeltaIterationSanityCheckTest extends Serializable {
(result, ws)
}
val output = iteration.print()
iteration.print()
}
@Test
......@@ -54,7 +51,7 @@ class DeltaIterationSanityCheckTest extends Serializable {
(result, ws)
}
val output = iteration.print()
iteration.print()
}
@Test(expected = classOf[InvalidProgramException])
......@@ -68,7 +65,7 @@ class DeltaIterationSanityCheckTest extends Serializable {
(result, ws)
}
val output = iteration.print()
iteration.print()
}
@Test(expected = classOf[InvalidProgramException])
......@@ -82,7 +79,7 @@ class DeltaIterationSanityCheckTest extends Serializable {
(result, ws)
}
val output = iteration.print() }
iteration.print() }
@Test(expected = classOf[InvalidProgramException])
def testIncorrectJoinWithSolution3(): Unit = {
......@@ -95,7 +92,7 @@ class DeltaIterationSanityCheckTest extends Serializable {
(result, ws)
}
val output = iteration.print()
iteration.print()
}
@Test
......@@ -109,7 +106,7 @@ class DeltaIterationSanityCheckTest extends Serializable {
(result, ws)
}
val output = iteration.print()
iteration.print()
}
@Test
......@@ -123,7 +120,7 @@ class DeltaIterationSanityCheckTest extends Serializable {
(result, ws)
}
val output = iteration.print()
iteration.print()
}
@Test(expected = classOf[InvalidProgramException])
......@@ -137,7 +134,7 @@ class DeltaIterationSanityCheckTest extends Serializable {
(result, ws)
}
val output = iteration.print()
iteration.print()
}
@Test(expected = classOf[InvalidProgramException])
......@@ -151,7 +148,7 @@ class DeltaIterationSanityCheckTest extends Serializable {
(result, ws)
}
val output = iteration.print() }
iteration.print() }
@Test(expected = classOf[InvalidProgramException])
def testIncorrectCoGroupWithSolution3(): Unit = {
......@@ -164,6 +161,6 @@ class DeltaIterationSanityCheckTest extends Serializable {
(result, ws)
}
val output = iteration.print()
iteration.print()
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册