提交 5e40c6d4 编写于 作者: F Fabian Hueske 提交者: Maximilian Michels

[FLINK-2953] fix chaining of sortPartition() calls in Scala DataSet API

- Added tests for Scala DataSet sortPartition

This closes #1317.
上级 95a1abd1
......@@ -1441,7 +1441,8 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
* The DataSet can be sorted on multiple fields by chaining sortPartition() calls.
*/
def sortPartition(field: Int, order: Order): DataSet[T] = {
wrap (new SortPartitionOperator[T](javaSet, field, order, getCallLocationName()))
new PartitionSortedDataSet[T] (
new SortPartitionOperator[T](javaSet, field, order, getCallLocationName()))
}
/**
......@@ -1449,7 +1450,8 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
* The DataSet can be sorted on multiple fields by chaining sortPartition() calls.
*/
def sortPartition(field: String, order: Order): DataSet[T] = {
wrap (new SortPartitionOperator[T](javaSet, field, order, getCallLocationName()))
new PartitionSortedDataSet[T](
new SortPartitionOperator[T](javaSet, field, order, getCallLocationName()))
}
// --------------------------------------------------------------------------------------------
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.java.operators.SortPartitionOperator
import scala.reflect.ClassTag
/**
* The result of [[DataSet.sortPartition]]. This can be used to append additional sort fields to the
* one sort-partition operator.
*
* @tparam T The type of the DataSet, i.e., the type of the elements of the DataSet.
*/
class PartitionSortedDataSet[T: ClassTag](set: SortPartitionOperator[T])
extends DataSet[T](set) {
/**
* Appends the given field and order to the sort-partition operator.
*/
override def sortPartition(field: Int, order: Order): DataSet[T] = {
this.set.sortPartition(field, order)
this
}
/**
* Appends the given field and order to the sort-partition operator.
*/
override def sortPartition(field: String, order: Order): DataSet[T] = {
this.set.sortPartition(field, order)
this
}
}
......@@ -56,10 +56,10 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
env.setParallelism(4);
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
List result = ds
List<Tuple1<Boolean>> result = ds
.map(new IdMapper()).setParallelism(4) // parallelize input
.sortPartition(1, Order.DESCENDING)
.mapPartition(new OrderCheckMapper<Tuple3<Integer, Long, String>>(new Tuple3Checker()))
.mapPartition(new OrderCheckMapper<>(new Tuple3Checker()))
.distinct().collect();
String expected = "(true)\n";
......@@ -77,11 +77,11 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
env.setParallelism(2);
DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds = CollectionDataSets.get5TupleDataSet(env);
List result = ds
List<Tuple1<Boolean>> result = ds
.map(new IdMapper()).setParallelism(2) // parallelize input
.sortPartition(4, Order.ASCENDING)
.sortPartition(2, Order.DESCENDING)
.mapPartition(new OrderCheckMapper<Tuple5<Integer, Long, Integer, String, Long>>(new Tuple5Checker()))
.mapPartition(new OrderCheckMapper<>(new Tuple5Checker()))
.distinct().collect();
String expected = "(true)\n";
......@@ -100,10 +100,10 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
env.setParallelism(4);
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
List result = ds
List<Tuple1<Boolean>> result = ds
.map(new IdMapper()).setParallelism(4) // parallelize input
.sortPartition("f1", Order.DESCENDING)
.mapPartition(new OrderCheckMapper<Tuple3<Integer, Long, String>>(new Tuple3Checker()))
.mapPartition(new OrderCheckMapper<>(new Tuple3Checker()))
.distinct().collect();
String expected = "(true)\n";
......@@ -121,11 +121,11 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
env.setParallelism(2);
DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds = CollectionDataSets.get5TupleDataSet(env);
List result = ds
List<Tuple1<Boolean>> result = ds
.map(new IdMapper()).setParallelism(2) // parallelize input
.sortPartition("f4", Order.ASCENDING)
.sortPartition("f2", Order.DESCENDING)
.mapPartition(new OrderCheckMapper<Tuple5<Integer, Long, Integer, String, Long>>(new Tuple5Checker()))
.mapPartition(new OrderCheckMapper<>(new Tuple5Checker()))
.distinct().collect();
String expected = "(true)\n";
......@@ -143,11 +143,11 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
env.setParallelism(3);
DataSet<Tuple2<Tuple2<Integer, Integer>, String>> ds = CollectionDataSets.getGroupSortedNestedTupleDataSet(env);
List result = ds
List<Tuple1<Boolean>> result = ds
.map(new IdMapper()).setParallelism(3) // parallelize input
.sortPartition("f0.f1", Order.ASCENDING)
.sortPartition("f1", Order.DESCENDING)
.mapPartition(new OrderCheckMapper<Tuple2<Tuple2<Integer, Integer>, String>>(new NestedTupleChecker()))
.mapPartition(new OrderCheckMapper<>(new NestedTupleChecker()))
.distinct().collect();
String expected = "(true)\n";
......@@ -165,11 +165,11 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
env.setParallelism(3);
DataSet<POJO> ds = CollectionDataSets.getMixedPojoDataSet(env);
List result = ds
List<Tuple1<Boolean>> result = ds
.map(new IdMapper()).setParallelism(1) // parallelize input
.sortPartition("nestedTupleWithCustom.f1.myString", Order.ASCENDING)
.sortPartition("number", Order.DESCENDING)
.mapPartition(new OrderCheckMapper<POJO>(new PojoChecker()))
.mapPartition(new OrderCheckMapper<>(new PojoChecker()))
.distinct().collect();
String expected = "(true)\n";
......@@ -187,9 +187,9 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
env.setParallelism(3);
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
List result = ds
List<Tuple1<Boolean>> result = ds
.sortPartition(1, Order.DESCENDING).setParallelism(3) // change parallelism
.mapPartition(new OrderCheckMapper<Tuple3<Integer, Long, String>>(new Tuple3Checker()))
.mapPartition(new OrderCheckMapper<>(new Tuple3Checker()))
.distinct().collect();
String expected = "(true)\n";
......@@ -202,6 +202,7 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
public boolean inOrder(T t1, T t2);
}
@SuppressWarnings("serial")
public static class Tuple3Checker implements OrderChecker<Tuple3<Integer, Long, String>> {
@Override
public boolean inOrder(Tuple3<Integer, Long, String> t1, Tuple3<Integer, Long, String> t2) {
......@@ -209,6 +210,7 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
}
}
@SuppressWarnings("serial")
public static class Tuple5Checker implements OrderChecker<Tuple5<Integer, Long, Integer, String, Long>> {
@Override
public boolean inOrder(Tuple5<Integer, Long, Integer, String, Long> t1,
......@@ -217,6 +219,7 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
}
}
@SuppressWarnings("serial")
public static class NestedTupleChecker implements OrderChecker<Tuple2<Tuple2<Integer, Integer>, String>> {
@Override
public boolean inOrder(Tuple2<Tuple2<Integer, Integer>, String> t1,
......@@ -226,6 +229,7 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
}
}
@SuppressWarnings("serial")
public static class PojoChecker implements OrderChecker<POJO> {
@Override
public boolean inOrder(POJO t1, POJO t2) {
......@@ -235,6 +239,7 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
}
}
@SuppressWarnings("unused, serial")
public static class OrderCheckMapper<T> implements MapPartitionFunction<T, Tuple1<Boolean>> {
OrderChecker<T> checker;
......@@ -250,7 +255,7 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
Iterator<T> it = values.iterator();
if(!it.hasNext()) {
out.collect(new Tuple1<Boolean>(true));
out.collect(new Tuple1<>(true));
return;
} else {
T last = it.next();
......@@ -258,17 +263,17 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
while (it.hasNext()) {
T next = it.next();
if (!checker.inOrder(last, next)) {
out.collect(new Tuple1<Boolean>(false));
out.collect(new Tuple1<>(false));
return;
}
last = next;
}
out.collect(new Tuple1<Boolean>(true));
out.collect(new Tuple1<>(true));
}
}
}
@SuppressWarnings("serial")
public static class IdMapper<T> implements MapFunction<T, T> {
@Override
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.operators
import java.io.Serializable
import java.lang
import scala.collection.JavaConverters._
import org.apache.flink.api.common.functions.MapPartitionFunction
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.apache.flink.test.util.{TestBaseUtils, MultipleProgramsTestBase}
import org.apache.flink.util.Collector
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.Test
@RunWith(classOf[Parameterized])
class SortPartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) {
@Test
@throws(classOf[Exception])
def testSortPartitionByKeyField(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(4)
val ds = CollectionDataSets.get3TupleDataSet(env)
val result = ds
.map { x => x }.setParallelism(4)
.sortPartition(1, Order.DESCENDING)
.mapPartition(new OrderCheckMapper(new Tuple3Checker))
.distinct()
.collect()
val expected: String = "(true)\n"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
@throws(classOf[Exception])
def testSortPartitionByTwoKeyFields(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(2)
val ds = CollectionDataSets.get5TupleDataSet(env)
val result = ds
.map { x => x }.setParallelism(2)
.sortPartition(4, Order.ASCENDING)
.sortPartition(2, Order.DESCENDING)
.mapPartition(new OrderCheckMapper(new Tuple5Checker))
.distinct()
.collect()
val expected: String = "(true)\n"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
@throws(classOf[Exception])
def testSortPartitionByFieldExpression(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(4)
val ds = CollectionDataSets.get3TupleDataSet(env)
val result = ds
.map { x => x }.setParallelism(4)
.sortPartition("_2", Order.DESCENDING)
.mapPartition(new OrderCheckMapper(new Tuple3Checker))
.distinct()
.collect()
val expected: String = "(true)\n"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
@throws(classOf[Exception])
def testSortPartitionByTwoFieldExpressions(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(2)
val ds = CollectionDataSets.get5TupleDataSet(env)
val result = ds
.map { x => x }.setParallelism(2)
.sortPartition("_5", Order.ASCENDING)
.sortPartition("_3", Order.DESCENDING)
.mapPartition(new OrderCheckMapper(new Tuple5Checker))
.distinct()
.collect()
val expected: String = "(true)\n"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
@throws(classOf[Exception])
def testSortPartitionByNestedFieldExpression(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(3)
val ds = CollectionDataSets.getGroupSortedNestedTupleDataSet(env)
val result = ds
.map { x => x }.setParallelism(3)
.sortPartition("_1._2", Order.ASCENDING)
.sortPartition("_2", Order.DESCENDING)
.mapPartition(new OrderCheckMapper(new NestedTupleChecker))
.distinct()
.collect()
val expected: String = "(true)\n"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
@throws(classOf[Exception])
def testSortPartitionPojoByNestedFieldExpression(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(3)
val ds = CollectionDataSets.getMixedPojoDataSet(env)
val result = ds
.map { x => x }.setParallelism(3)
.sortPartition("nestedTupleWithCustom._2.myString", Order.ASCENDING)
.sortPartition("number", Order.DESCENDING)
.mapPartition(new OrderCheckMapper(new PojoChecker))
.distinct()
.collect()
val expected: String = "(true)\n"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
@throws(classOf[Exception])
def testSortPartitionParallelismChange(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(3)
val ds = CollectionDataSets.get3TupleDataSet(env)
val result = ds
.sortPartition(1, Order.DESCENDING).setParallelism(3)
.mapPartition(new OrderCheckMapper(new Tuple3Checker))
.distinct()
.collect()
val expected: String = "(true)\n"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
}
trait OrderChecker[T] extends Serializable {
def inOrder(t1: T, t2: T): Boolean
}
class Tuple3Checker extends OrderChecker[(Int, Long, String)] {
def inOrder(t1: (Int, Long, String), t2: (Int, Long, String)): Boolean = {
t1._2 >= t2._2
}
}
class Tuple5Checker extends OrderChecker[(Int, Long, Int, String, Long)] {
def inOrder(t1: (Int, Long, Int, String, Long), t2: (Int, Long, Int, String, Long)): Boolean = {
t1._5 < t2._5 || t1._5 == t2._5 && t1._3 >= t2._3
}
}
class NestedTupleChecker extends OrderChecker[((Int, Int), String)] {
def inOrder(t1: ((Int, Int), String), t2: ((Int, Int), String)): Boolean = {
t1._1._2 < t2._1._2 || t1._1._2 == t2._1._2 && t1._2.compareTo(t2._2) >= 0
}
}
class PojoChecker extends OrderChecker[CollectionDataSets.POJO] {
def inOrder(t1: CollectionDataSets.POJO, t2: CollectionDataSets.POJO): Boolean = {
t1.nestedTupleWithCustom._2.myString.compareTo(t2.nestedTupleWithCustom._2.myString) < 0 ||
t1.nestedTupleWithCustom._2.myString.compareTo(t2.nestedTupleWithCustom._2.myString) == 0 &&
t1.number >= t2.number
}
}
class OrderCheckMapper[T](checker: OrderChecker[T])
extends MapPartitionFunction[T, Tuple1[Boolean]] {
override def mapPartition(values: lang.Iterable[T], out: Collector[Tuple1[Boolean]]): Unit = {
val it = values.iterator()
if (!it.hasNext) {
out.collect(new Tuple1(true))
}
else {
var last: T = it.next()
while (it.hasNext) {
val next: T = it.next()
if (!checker.inOrder(last, next)) {
out.collect(new Tuple1(false))
return
}
last = next
}
out.collect(new Tuple1(true))
}
}
}
......@@ -271,6 +271,19 @@ object CollectionDataSets {
env.fromCollection(data)
}
def getMixedPojoDataSet(env: ExecutionEnvironment): DataSet[POJO] = {
val data = new mutable.MutableList[POJO]
data.+=(new POJO(1, "First", 10, 100, 1000L, "One", 10100L))
data.+=(new POJO(2, "First_", 10, 105, 1000L, "One", 10200L))
data.+=(new POJO(3, "First", 11, 102, 3000L, "One", 10200L))
data.+=(new POJO(4, "First_", 11, 106, 1000L, "One", 10300L))
data.+=(new POJO(5, "First", 11, 102, 2000L, "One", 10100L))
data.+=(new POJO(6, "Second_", 20, 200, 2000L, "Two", 10100L))
data.+=(new POJO(7, "Third", 31, 301, 2000L, "Three", 10200L))
data.+=(new POJO(8, "Third_", 30, 300, 1000L, "Three", 10100L))
env.fromCollection(data)
}
def getSmallTuplebasedDataSetMatchingPojo(env: ExecutionEnvironment):
DataSet[(Long, Integer, Integer, Long, String, Integer, String)] = {
val data = new mutable.MutableList[(Long, Integer, Integer, Long, String, Integer, String)]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册