提交 2b16c604 编写于 作者: F Fabian Hueske

[FLINK-1112] Additional checks for KeySelector group sorting and minor fixes

上级 f83db149
......@@ -115,7 +115,7 @@ public class GroupingKeySelectorTranslationTest extends CompilerTestBase {
data.groupBy(new TestKeySelector<Tuple3<Integer,Integer,Integer>>())
.withPartitioner(new TestPartitionerInt())
.sortGroup(1, Order.ASCENDING)
.sortGroup(new TestKeySelector<Tuple3<Integer, Integer, Integer>>(), Order.ASCENDING)
.reduceGroup(new IdentityGroupReducer<Tuple3<Integer,Integer,Integer>>())
.print();
......@@ -136,39 +136,6 @@ public class GroupingKeySelectorTranslationTest extends CompilerTestBase {
}
}
/**
 * Builds a group-reduce program over Tuple4 elements that is grouped with a KeySelector,
 * given a custom partitioner, and group-sorted on field 1 (ascending) and field 2
 * (descending), then compiles it and checks the ship strategies of the resulting plan.
 *
 * NOTE(review): this combines KeySelector grouping with field-index group sorting; confirm
 * that sortGroup(int, Order) still accepts that combination before relying on this test.
 */
@Test
public void testCustomPartitioningKeySelectorGroupReduceSorted2() {
try {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
// Single-element input, rebalanced and run at parallelism 4.
DataSet<Tuple4<Integer,Integer,Integer, Integer>> data = env.fromElements(new Tuple4<Integer,Integer,Integer,Integer>(0, 0, 0, 0))
.rebalance().setParallelism(4);
data
.groupBy(new TestKeySelector<Tuple4<Integer,Integer,Integer,Integer>>())
.withPartitioner(new TestPartitionerInt())
.sortGroup(1, Order.ASCENDING)
.sortGroup(2, Order.DESCENDING)
.reduceGroup(new IdentityGroupReducer<Tuple4<Integer,Integer,Integer,Integer>>())
.print();
Plan p = env.createProgramPlan();
OptimizedPlan op = compileNoStats(p);
// Walk the plan backwards from the sink: sink <- reducer <- combiner.
SinkPlanNode sink = op.getDataSinks().iterator().next();
SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
// Only the reducer input is expected to use the custom partitioning; the rest forward.
assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
}
catch (Exception e) {
// Any exception while building or compiling the plan fails the test.
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testCustomPartitioningKeySelectorInvalidType() {
try {
......
......@@ -92,7 +92,7 @@ public class SortedGrouping<T> extends Grouping<T> {
super(set, keys);
if (!(this.keys instanceof Keys.SelectorFunctionKeys)) {
throw new InvalidProgramException("Sorting on KeySelector only works for KeySelector grouping.");
throw new InvalidProgramException("Sorting on KeySelector keys only works with KeySelector grouping.");
}
this.groupSortKeyPositions = keySelector.computeLogicalKeyPositions();
......
......@@ -231,6 +231,10 @@ public class UnsortedGrouping<T> extends Grouping<T> {
* @see Order
*/
public SortedGrouping<T> sortGroup(int field, Order order) {
if (this.getKeys() instanceof Keys.SelectorFunctionKeys) {
throw new InvalidProgramException("KeySelector grouping keys and field index group-sorting keys cannot be used together.");
}
SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, field, order);
sg.customPartitioner = getCustomPartitioner();
return sg;
......@@ -248,6 +252,10 @@ public class UnsortedGrouping<T> extends Grouping<T> {
* @see Order
*/
public SortedGrouping<T> sortGroup(String field, Order order) {
if (this.getKeys() instanceof Keys.SelectorFunctionKeys) {
throw new InvalidProgramException("KeySelector grouping keys and field expression group-sorting keys cannot be used together.");
}
SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, field, order);
sg.customPartitioner = getCustomPartitioner();
return sg;
......@@ -256,7 +264,6 @@ public class UnsortedGrouping<T> extends Grouping<T> {
/**
* Sorts elements within a group on a key extracted by the specified {@link org.apache.flink.api.java.functions.KeySelector}
* in the specified {@link Order}.<br/>
* <b>Note: Only groups of Tuple elements and Pojos can be sorted.</b><br/>
* Chaining {@link #sortGroup(KeySelector, Order)} calls is not supported.
*
* @param keySelector The KeySelector with which the group is sorted.
......@@ -266,8 +273,14 @@ public class UnsortedGrouping<T> extends Grouping<T> {
* @see Order
*/
public <K> SortedGrouping<T> sortGroup(KeySelector<T, K> keySelector, Order order) {
	// KeySelector sort keys are only valid on top of a KeySelector grouping.
	if (!(this.getKeys() instanceof Keys.SelectorFunctionKeys)) {
		throw new InvalidProgramException("KeySelector group-sorting keys can only be used with KeySelector grouping keys.");
	}

	TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keySelector, this.dataSet.getType());
	SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, new Keys.SelectorFunctionKeys<T, K>(keySelector, this.dataSet.getType(), keyType), order);
	// Carry the custom partitioner over to the sorted grouping, as the field-index and
	// field-expression overloads do; returning the grouping before this step would lose it.
	sg.customPartitioner = getCustomPartitioner();
	return sg;
}
}
......@@ -67,6 +67,10 @@ class GroupedDataSet[T: ClassTag](
if (field >= set.getType.getArity) {
throw new IllegalArgumentException("Order key out of tuple bounds.")
}
if (keys.isInstanceOf[Keys.SelectorFunctionKeys[_, _]]) {
throw new InvalidProgramException("KeySelector grouping keys and field index group-sorting " +
"keys cannot be used together.")
}
if (groupSortKeySelector.nonEmpty) {
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not " +
"supported.")
......@@ -87,6 +91,10 @@ class GroupedDataSet[T: ClassTag](
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not" +
"supported.")
}
if (keys.isInstanceOf[Keys.SelectorFunctionKeys[_, _]]) {
throw new InvalidProgramException("KeySelector grouping keys and field expression " +
"group-sorting keys cannot be used together.")
}
groupSortKeyPositions += Right(field)
groupSortOrders += order
this
......@@ -103,6 +111,10 @@ class GroupedDataSet[T: ClassTag](
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not" +
"supported.")
}
if (!keys.isInstanceOf[Keys.SelectorFunctionKeys[_, _]]) {
throw new InvalidProgramException("Sorting on KeySelector keys only works with KeySelector " +
"grouping.")
}
groupSortOrders += order
val keyType = implicitly[TypeInformation[K]]
......@@ -121,7 +133,12 @@ class GroupedDataSet[T: ClassTag](
private def maybeCreateSortedGrouping(): Grouping[T] = {
groupSortKeySelector match {
case Some(keySelector) =>
new SortedGrouping[T](set.javaSet, keys, keySelector, groupSortOrders(0))
if (partitioner == null) {
new SortedGrouping[T](set.javaSet, keys, keySelector, groupSortOrders(0))
} else {
new SortedGrouping[T](set.javaSet, keys, keySelector, groupSortOrders(0))
.withPartitioner(partitioner)
}
case None =>
if (groupSortKeyPositions.length > 0) {
val grouping = groupSortKeyPositions(0) match {
......
......@@ -835,7 +835,6 @@ public class GroupReduceITCase extends MultipleProgramsTestBase {
.reduceGroup(new Tuple3SortedGroupReduceWithCombine());
reduceDs.writeAsCsv(resultPath);
reduceDs.print();
env.execute();
// return expected result
......@@ -1302,7 +1301,6 @@ public class GroupReduceITCase extends MultipleProgramsTestBase {
public void combine(Iterable<Tuple3<Integer, Long, String>> values, Collector<Tuple3<Integer, Long, String>> out) {
int sum = 0;
long key = 0;
System.out.println("im in");
StringBuilder concat = new StringBuilder();
for (Tuple3<Integer, Long, String> next : values) {
......
......@@ -25,32 +25,31 @@ import org.apache.flink.api.common.functions.Partitioner
import org.apache.flink.runtime.operators.shipping.ShipStrategyType
import org.apache.flink.compiler.plan.SingleInputPlanNode
import org.apache.flink.test.compiler.util.CompilerTestBase
import scala.collection.immutable.Seq
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.common.InvalidProgramException
class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
@Test
def testCustomPartitioningKeySelectorReduce() {
try {
val env = ExecutionEnvironment.getExecutionEnvironment
val data = env.fromElements( (0,0) ).rebalance().setParallelism(4)
data
.groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
.reduce( (a,b) => a )
.print()
.groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
.reduce( (a,b) => a )
.print()
val p = env.createProgramPlan()
val op = compileNoStats(p)
val sink = op.getDataSinks.iterator().next()
val keyRemovingMapper = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
val reducer = keyRemovingMapper.getInput.getSource.asInstanceOf[SingleInputPlanNode]
val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
assertEquals(ShipStrategyType.FORWARD, keyRemovingMapper.getInput.getShipStrategy)
assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
......@@ -63,26 +62,27 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
}
}
}
@Test
def testCustomPartitioningKeySelectorGroupReduce() {
try {
val env = ExecutionEnvironment.getExecutionEnvironment
val data = env.fromElements( (0,0) ).rebalance().setParallelism(4)
data
.groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
.reduceGroup( iter => Seq(iter.next()) )
.print()
.groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
.reduce( (a, b) => a)
.print()
val p = env.createProgramPlan()
val op = compileNoStats(p)
val sink = op.getDataSinks.iterator().next()
val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
.getInput.getSource.asInstanceOf[SingleInputPlanNode]
val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
......@@ -94,28 +94,63 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
}
}
}
/**
 * Groups a three-field tuple data set on field 0 with a custom partitioner, sorts each
 * group on field 1, and checks the ship strategies chosen for the compiled plan.
 */
@Test
def testCustomPartitioningIndexGroupReduceSorted() {
  try {
    val environment = ExecutionEnvironment.getExecutionEnvironment
    val input = environment.fromElements( (0,0,0) ).rebalance().setParallelism(4)

    val partitionedGroups = input.groupBy(0).withPartitioner(new TestPartitionerInt())
    partitionedGroups
      .sortGroup(1, Order.ASCENDING)
      .reduce( (a,b) => a)
      .print()

    val optimizedPlan = compileNoStats(environment.createProgramPlan())

    // Walk the plan backwards from the sink: sink <- reducer <- combiner.
    val sinkNode = optimizedPlan.getDataSinks.iterator().next()
    val reducerNode = sinkNode.getInput.getSource.asInstanceOf[SingleInputPlanNode]
    val combinerNode = reducerNode.getInput.getSource.asInstanceOf[SingleInputPlanNode]

    assertEquals(ShipStrategyType.FORWARD, sinkNode.getInput.getShipStrategy)
    assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducerNode.getInput.getShipStrategy)
    assertEquals(ShipStrategyType.FORWARD, combinerNode.getInput.getShipStrategy)
  }
  catch {
    // Any exception while building or compiling the plan fails the test.
    case e: Exception =>
      e.printStackTrace()
      fail(e.getMessage)
  }
}
@Test
def testCustomPartitioningKeySelectorGroupReduceSorted() {
try {
val env = ExecutionEnvironment.getExecutionEnvironment
val data = env.fromElements( (0,0,0) ).rebalance().setParallelism(4)
data
.groupBy( _._1 )
.withPartitioner(new TestPartitionerInt())
.sortGroup(1, Order.ASCENDING)
.reduceGroup( iter => Seq(iter.next()) )
.print()
.groupBy(_._1)
.withPartitioner(new TestPartitionerInt())
.sortGroup(_._2, Order.ASCENDING)
.reduce( (a,b) => a)
.print()
val p = env.createProgramPlan()
val op = compileNoStats(p)
val sink = op.getDataSinks.iterator().next()
val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
.getInput.getSource.asInstanceOf[SingleInputPlanNode]
val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
......@@ -127,31 +162,32 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
}
}
}
@Test
def testCustomPartitioningKeySelectorGroupReduceSorted2() {
try {
val env = ExecutionEnvironment.getExecutionEnvironment
val data = env.fromElements( (0,0,0,0) ).rebalance().setParallelism(4)
data
.groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
.sortGroup(1, Order.ASCENDING)
.sortGroup(2, Order.DESCENDING)
.reduceGroup( iter => Seq(iter.next()) )
.print()
.groupBy(0).withPartitioner(new TestPartitionerInt())
.sortGroup(1, Order.ASCENDING)
.sortGroup(2, Order.DESCENDING)
.reduce( (a,b) => a)
.print()
val p = env.createProgramPlan()
val op = compileNoStats(p)
val sink = op.getDataSinks.iterator().next()
val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
}
catch {
case e: Exception => {
......@@ -160,22 +196,22 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
}
}
}
@Test
def testCustomPartitioningKeySelectorInvalidType() {
try {
val env = ExecutionEnvironment.getExecutionEnvironment
val data = env.fromElements( (0, 0) ).rebalance().setParallelism(4)
try {
data
.groupBy( _._1 )
.withPartitioner(new TestPartitionerLong())
.groupBy( _._1 )
.withPartitioner(new TestPartitionerLong())
fail("Should throw an exception")
}
catch {
case e: InvalidProgramException =>
case e: InvalidProgramException =>
}
}
catch {
......@@ -185,23 +221,23 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
}
}
}
@Test
def testCustomPartitioningKeySelectorInvalidTypeSorted() {
try {
val env = ExecutionEnvironment.getExecutionEnvironment
val data = env.fromElements( (0, 0, 0) ).rebalance().setParallelism(4)
try {
data
.groupBy( _._1 )
.sortGroup(1, Order.ASCENDING)
.withPartitioner(new TestPartitionerLong())
.groupBy( _._1 )
.sortGroup(1, Order.ASCENDING)
.withPartitioner(new TestPartitionerLong())
fail("Should throw an exception")
}
catch {
case e: InvalidProgramException =>
case e: InvalidProgramException =>
}
}
catch {
......@@ -211,20 +247,20 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
}
}
}
@Test
def testCustomPartitioningTupleRejectCompositeKey() {
try {
val env = ExecutionEnvironment.getExecutionEnvironment
val data = env.fromElements( (0, 0, 0) ).rebalance().setParallelism(4)
try {
data.groupBy( v => (v._1, v._2) ).withPartitioner(new TestPartitionerInt())
fail("Should throw an exception")
}
catch {
case e: InvalidProgramException =>
case e: InvalidProgramException =>
}
}
catch {
......@@ -234,16 +270,16 @@ class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
}
}
}
// ----------------------------------------------------------------------------------------------
/** Test partitioner for Int keys that assigns every key to partition 0. */
private class TestPartitionerInt extends Partitioner[Int] {
  override def partition(key: Int, numPartitions: Int): Int = {
    0
  }
}
/** Test partitioner for Long keys that assigns every key to partition 0. */
private class TestPartitionerLong extends Partitioner[Long] {
  override def partition(key: Long, numPartitions: Int): Int = {
    0
  }
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册