提交 2b16c604 作者: Fabian Hueske

[FLINK-1112] Additional checks for KeySelector group sorting and minor fixes

上级 f83db149
...@@ -115,7 +115,7 @@ public class GroupingKeySelectorTranslationTest extends CompilerTestBase { ...@@ -115,7 +115,7 @@ public class GroupingKeySelectorTranslationTest extends CompilerTestBase {
data.groupBy(new TestKeySelector<Tuple3<Integer,Integer,Integer>>()) data.groupBy(new TestKeySelector<Tuple3<Integer,Integer,Integer>>())
.withPartitioner(new TestPartitionerInt()) .withPartitioner(new TestPartitionerInt())
.sortGroup(1, Order.ASCENDING) .sortGroup(new TestKeySelector<Tuple3<Integer, Integer, Integer>>(), Order.ASCENDING)
.reduceGroup(new IdentityGroupReducer<Tuple3<Integer,Integer,Integer>>()) .reduceGroup(new IdentityGroupReducer<Tuple3<Integer,Integer,Integer>>())
.print(); .print();
...@@ -136,39 +136,6 @@ public class GroupingKeySelectorTranslationTest extends CompilerTestBase { ...@@ -136,39 +136,6 @@ public class GroupingKeySelectorTranslationTest extends CompilerTestBase {
} }
} }
/**
 * Builds a program over four-field tuples that groups with a {@code TestKeySelector},
 * attaches a {@code TestPartitionerInt} as custom partitioner, sorts each group on field
 * index 1 (ascending) and field index 2 (descending), and finishes with an identity
 * group reduce. The compiled plan is then inspected for the ship strategies on the
 * inputs of the sink and of the two plan nodes upstream of it.
 *
 * NOTE(review): the {@code sortGroup(int, Order)} overload shown elsewhere in this change
 * rejects KeySelector grouping keys with an InvalidProgramException; presumably that is
 * why this scenario is being dropped. Confirm before reviving it.
 */
@Test
public void testCustomPartitioningKeySelectorGroupReduceSorted2() {
    try {
        final ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();

        // Single all-zero element, rebalanced, with parallelism fixed to 4.
        final DataSet<Tuple4<Integer,Integer,Integer, Integer>> input =
            environment
                .fromElements(new Tuple4<Integer,Integer,Integer,Integer>(0, 0, 0, 0))
                .rebalance()
                .setParallelism(4);

        // KeySelector grouping + custom partitioner + two index-based group sorts.
        input
            .groupBy(new TestKeySelector<Tuple4<Integer,Integer,Integer,Integer>>())
            .withPartitioner(new TestPartitionerInt())
            .sortGroup(1, Order.ASCENDING)
            .sortGroup(2, Order.DESCENDING)
            .reduceGroup(new IdentityGroupReducer<Tuple4<Integer,Integer,Integer,Integer>>())
            .print();

        final OptimizedPlan optimized = compileNoStats(environment.createProgramPlan());

        // Walk upstream from the first sink: sink <- reduce node <- combine node.
        final SinkPlanNode sinkNode = optimized.getDataSinks().iterator().next();
        final SingleInputPlanNode reduceNode = (SingleInputPlanNode) sinkNode.getInput().getSource();
        final SingleInputPlanNode combineNode = (SingleInputPlanNode) reduceNode.getInput().getSource();

        assertEquals(ShipStrategyType.FORWARD, sinkNode.getInput().getShipStrategy());
        assertEquals(ShipStrategyType.PARTITION_CUSTOM, reduceNode.getInput().getShipStrategy());
        assertEquals(ShipStrategyType.FORWARD, combineNode.getInput().getShipStrategy());
    }
    catch (Exception failure) {
        // Any exception during plan construction or compilation fails the test.
        failure.printStackTrace();
        fail(failure.getMessage());
    }
}
@Test @Test
public void testCustomPartitioningKeySelectorInvalidType() { public void testCustomPartitioningKeySelectorInvalidType() {
try { try {
......
...@@ -92,7 +92,7 @@ public class SortedGrouping<T> extends Grouping<T> { ...@@ -92,7 +92,7 @@ public class SortedGrouping<T> extends Grouping<T> {
super(set, keys); super(set, keys);
if (!(this.keys instanceof Keys.SelectorFunctionKeys)) { if (!(this.keys instanceof Keys.SelectorFunctionKeys)) {
throw new InvalidProgramException("Sorting on KeySelector only works for KeySelector grouping."); throw new InvalidProgramException("Sorting on KeySelector keys only works with KeySelector grouping.");
} }
this.groupSortKeyPositions = keySelector.computeLogicalKeyPositions(); this.groupSortKeyPositions = keySelector.computeLogicalKeyPositions();
......
...@@ -231,6 +231,10 @@ public class UnsortedGrouping<T> extends Grouping<T> { ...@@ -231,6 +231,10 @@ public class UnsortedGrouping<T> extends Grouping<T> {
* @see Order * @see Order
*/ */
public SortedGrouping<T> sortGroup(int field, Order order) { public SortedGrouping<T> sortGroup(int field, Order order) {
if (this.getKeys() instanceof Keys.SelectorFunctionKeys) {
throw new InvalidProgramException("KeySelector grouping keys and field index group-sorting keys cannot be used together.");
}
SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, field, order); SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, field, order);
sg.customPartitioner = getCustomPartitioner(); sg.customPartitioner = getCustomPartitioner();
return sg; return sg;
...@@ -248,6 +252,10 @@ public class UnsortedGrouping<T> extends Grouping<T> { ...@@ -248,6 +252,10 @@ public class UnsortedGrouping<T> extends Grouping<T> {
* @see Order * @see Order
*/ */
public SortedGrouping<T> sortGroup(String field, Order order) { public SortedGrouping<T> sortGroup(String field, Order order) {
if (this.getKeys() instanceof Keys.SelectorFunctionKeys) {
throw new InvalidProgramException("KeySelector grouping keys and field expression group-sorting keys cannot be used together.");
}
SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, field, order); SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, field, order);
sg.customPartitioner = getCustomPartitioner(); sg.customPartitioner = getCustomPartitioner();
return sg; return sg;
...@@ -256,7 +264,6 @@ public class UnsortedGrouping<T> extends Grouping<T> { ...@@ -256,7 +264,6 @@ public class UnsortedGrouping<T> extends Grouping<T> {
/** /**
* Sorts elements within a group on a key extracted by the specified {@link org.apache.flink.api.java.functions.KeySelector} * Sorts elements within a group on a key extracted by the specified {@link org.apache.flink.api.java.functions.KeySelector}
* in the specified {@link Order}.</br> * in the specified {@link Order}.</br>
* <b>Note: Only groups of Tuple elements and Pojos can be sorted.</b><br/>
* Chaining {@link #sortGroup(KeySelector, Order)} calls is not supported. * Chaining {@link #sortGroup(KeySelector, Order)} calls is not supported.
* *
* @param keySelector The KeySelector with which the group is sorted. * @param keySelector The KeySelector with which the group is sorted.
...@@ -266,8 +273,14 @@ public class UnsortedGrouping<T> extends Grouping<T> { ...@@ -266,8 +273,14 @@ public class UnsortedGrouping<T> extends Grouping<T> {
* @see Order * @see Order
*/ */
public <K> SortedGrouping<T> sortGroup(KeySelector<T, K> keySelector, Order order) { public <K> SortedGrouping<T> sortGroup(KeySelector<T, K> keySelector, Order order) {
if (!(this.getKeys() instanceof Keys.SelectorFunctionKeys)) {
throw new InvalidProgramException("KeySelector group-sorting keys can only be used with KeySelector grouping keys.");
}
TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keySelector, this.dataSet.getType()); TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keySelector, this.dataSet.getType());
return new SortedGrouping<T>(this.dataSet, this.keys, new Keys.SelectorFunctionKeys<T, K>(keySelector, this.dataSet.getType(), keyType), order); SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, new Keys.SelectorFunctionKeys<T, K>(keySelector, this.dataSet.getType(), keyType), order);
sg.customPartitioner = getCustomPartitioner();
return sg;
} }
} }
...@@ -67,6 +67,10 @@ class GroupedDataSet[T: ClassTag]( ...@@ -67,6 +67,10 @@ class GroupedDataSet[T: ClassTag](
if (field >= set.getType.getArity) { if (field >= set.getType.getArity) {
throw new IllegalArgumentException("Order key out of tuple bounds.") throw new IllegalArgumentException("Order key out of tuple bounds.")
} }
if (keys.isInstanceOf[Keys.SelectorFunctionKeys[_, _]]) {
throw new InvalidProgramException("KeySelector grouping keys and field index group-sorting " +
"keys cannot be used together.")
}
if (groupSortKeySelector.nonEmpty) { if (groupSortKeySelector.nonEmpty) {
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not " + throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not " +
"supported.") "supported.")
...@@ -87,6 +91,10 @@ class GroupedDataSet[T: ClassTag]( ...@@ -87,6 +91,10 @@ class GroupedDataSet[T: ClassTag](
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not" + throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not" +
"supported.") "supported.")
} }
if (keys.isInstanceOf[Keys.SelectorFunctionKeys[_, _]]) {
throw new InvalidProgramException("KeySelector grouping keys and field expression " +
"group-sorting keys cannot be used together.")
}
groupSortKeyPositions += Right(field) groupSortKeyPositions += Right(field)
groupSortOrders += order groupSortOrders += order
this this
...@@ -103,6 +111,10 @@ class GroupedDataSet[T: ClassTag]( ...@@ -103,6 +111,10 @@ class GroupedDataSet[T: ClassTag](
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not" + throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not" +
"supported.") "supported.")
} }
if (!keys.isInstanceOf[Keys.SelectorFunctionKeys[_, _]]) {
throw new InvalidProgramException("Sorting on KeySelector keys only works with KeySelector " +
"grouping.")
}
groupSortOrders += order groupSortOrders += order
val keyType = implicitly[TypeInformation[K]] val keyType = implicitly[TypeInformation[K]]
...@@ -121,7 +133,12 @@ class GroupedDataSet[T: ClassTag]( ...@@ -121,7 +133,12 @@ class GroupedDataSet[T: ClassTag](
private def maybeCreateSortedGrouping(): Grouping[T] = { private def maybeCreateSortedGrouping(): Grouping[T] = {
groupSortKeySelector match { groupSortKeySelector match {
case Some(keySelector) => case Some(keySelector) =>
if (partitioner == null) {
new SortedGrouping[T](set.javaSet, keys, keySelector, groupSortOrders(0)) new SortedGrouping[T](set.javaSet, keys, keySelector, groupSortOrders(0))
} else {
new SortedGrouping[T](set.javaSet, keys, keySelector, groupSortOrders(0))
.withPartitioner(partitioner)
}
case None => case None =>
if (groupSortKeyPositions.length > 0) { if (groupSortKeyPositions.length > 0) {
val grouping = groupSortKeyPositions(0) match { val grouping = groupSortKeyPositions(0) match {
......
...@@ -835,7 +835,6 @@ public class GroupReduceITCase extends MultipleProgramsTestBase { ...@@ -835,7 +835,6 @@ public class GroupReduceITCase extends MultipleProgramsTestBase {
.reduceGroup(new Tuple3SortedGroupReduceWithCombine()); .reduceGroup(new Tuple3SortedGroupReduceWithCombine());
reduceDs.writeAsCsv(resultPath); reduceDs.writeAsCsv(resultPath);
reduceDs.print();
env.execute(); env.execute();
// return expected result // return expected result
...@@ -1302,7 +1301,6 @@ public class GroupReduceITCase extends MultipleProgramsTestBase { ...@@ -1302,7 +1301,6 @@ public class GroupReduceITCase extends MultipleProgramsTestBase {
public void combine(Iterable<Tuple3<Integer, Long, String>> values, Collector<Tuple3<Integer, Long, String>> out) { public void combine(Iterable<Tuple3<Integer, Long, String>> values, Collector<Tuple3<Integer, Long, String>> out) {
int sum = 0; int sum = 0;
long key = 0; long key = 0;
System.out.println("im in");
StringBuilder concat = new StringBuilder(); StringBuilder concat = new StringBuilder();
for (Tuple3<Integer, Long, String> next : values) { for (Tuple3<Integer, Long, String> next : values) {
......
...@@ -25,7 +25,6 @@ import org.apache.flink.api.common.functions.Partitioner ...@@ -25,7 +25,6 @@ import org.apache.flink.api.common.functions.Partitioner
import org.apache.flink.runtime.operators.shipping.ShipStrategyType import org.apache.flink.runtime.operators.shipping.ShipStrategyType
import org.apache.flink.compiler.plan.SingleInputPlanNode import org.apache.flink.compiler.plan.SingleInputPlanNode
import org.apache.flink.test.compiler.util.CompilerTestBase import org.apache.flink.test.compiler.util.CompilerTestBase
import scala.collection.immutable.Seq
import org.apache.flink.api.common.operators.Order import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.common.InvalidProgramException import org.apache.flink.api.common.InvalidProgramException
...@@ -73,7 +72,7 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase { ...@@ -73,7 +72,7 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
data data
.groupBy( _._1 ).withPartitioner(new TestPartitionerInt()) .groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
.reduceGroup( iter => Seq(iter.next()) ) .reduce( (a, b) => a)
.print() .print()
val p = env.createProgramPlan() val p = env.createProgramPlan()
...@@ -81,6 +80,7 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase { ...@@ -81,6 +80,7 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
val sink = op.getDataSinks.iterator().next() val sink = op.getDataSinks.iterator().next()
val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode] val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
.getInput.getSource.asInstanceOf[SingleInputPlanNode]
val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode] val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy) assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
...@@ -96,17 +96,17 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase { ...@@ -96,17 +96,17 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
} }
@Test @Test
def testCustomPartitioningKeySelectorGroupReduceSorted() { def testCustomPartitioningIndexGroupReduceSorted() {
try { try {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val data = env.fromElements( (0,0,0) ).rebalance().setParallelism(4) val data = env.fromElements( (0,0,0) ).rebalance().setParallelism(4)
data data
.groupBy( _._1 ) .groupBy(0)
.withPartitioner(new TestPartitionerInt()) .withPartitioner(new TestPartitionerInt())
.sortGroup(1, Order.ASCENDING) .sortGroup(1, Order.ASCENDING)
.reduceGroup( iter => Seq(iter.next()) ) .reduce( (a,b) => a)
.print() .print()
val p = env.createProgramPlan() val p = env.createProgramPlan()
...@@ -116,6 +116,41 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase { ...@@ -116,6 +116,41 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode] val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode] val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
}
catch {
case e: Exception => {
e.printStackTrace()
fail(e.getMessage)
}
}
}
@Test
def testCustomPartitioningKeySelectorGroupReduceSorted() {
try {
val env = ExecutionEnvironment.getExecutionEnvironment
val data = env.fromElements( (0,0,0) ).rebalance().setParallelism(4)
data
.groupBy(_._1)
.withPartitioner(new TestPartitionerInt())
.sortGroup(_._2, Order.ASCENDING)
.reduce( (a,b) => a)
.print()
val p = env.createProgramPlan()
val op = compileNoStats(p)
val sink = op.getDataSinks.iterator().next()
val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
.getInput.getSource.asInstanceOf[SingleInputPlanNode]
val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy) assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy) assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy) assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
...@@ -136,10 +171,10 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase { ...@@ -136,10 +171,10 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
val data = env.fromElements( (0,0,0,0) ).rebalance().setParallelism(4) val data = env.fromElements( (0,0,0,0) ).rebalance().setParallelism(4)
data data
.groupBy( _._1 ).withPartitioner(new TestPartitionerInt()) .groupBy(0).withPartitioner(new TestPartitionerInt())
.sortGroup(1, Order.ASCENDING) .sortGroup(1, Order.ASCENDING)
.sortGroup(2, Order.DESCENDING) .sortGroup(2, Order.DESCENDING)
.reduceGroup( iter => Seq(iter.next()) ) .reduce( (a,b) => a)
.print() .print()
val p = env.createProgramPlan() val p = env.createProgramPlan()
...@@ -152,6 +187,7 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase { ...@@ -152,6 +187,7 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy) assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy) assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy) assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
} }
catch { catch {
case e: Exception => { case e: Exception => {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册