Commit 0a63797a authored by Chiwan Park, committed by Fabian Hueske

[FLINK-3234] [dataSet] Add KeySelector support to sortPartition operation.

This closes #1585
Parent 572855da
......@@ -1381,6 +1381,24 @@ public abstract class DataSet<T> {
return new SortPartitionOperator<>(this, field, order, Utils.getCallLocationName());
}
/**
* Locally sorts the partitions of the DataSet on the extracted key in the specified order.
* The DataSet can be sorted on multiple values by returning a tuple from the KeySelector.
*
* Note that no additional sort keys can be appended to a KeySelector sort key. To sort
* the partitions by multiple values using a KeySelector, the KeySelector must return a tuple
* consisting of those values.
*
* @param keyExtractor The KeySelector function which extracts the key values from the DataSet
* on which the DataSet is sorted.
* @param order The order in which the DataSet is sorted.
* @return The DataSet with sorted local partitions.
*/
public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> keyExtractor, Order order) {
final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, getType());
return new SortPartitionOperator<>(this, new Keys.SelectorFunctionKeys<>(clean(keyExtractor), getType(), keyType), order, Utils.getCallLocationName());
}
// --------------------------------------------------------------------------------------------
// Top-K
// --------------------------------------------------------------------------------------------
......
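For readers skimming the diff, here is a minimal usage sketch of the new Java API (the class name and sample data are illustrative, not part of the commit): the first call sorts each partition on a single extracted key, the second sorts on two values by returning a Tuple2 from the KeySelector, as the javadoc above requires.

import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;

public class SortPartitionKeySelectorExample {

    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple3<Integer, Long, String>> input = env.fromElements(
                new Tuple3<>(2, 10L, "b"),
                new Tuple3<>(1, 20L, "a"),
                new Tuple3<>(1, 5L, "c"));

        // Sort every partition on a single extracted key.
        input.sortPartition(new KeySelector<Tuple3<Integer, Long, String>, Long>() {
            @Override
            public Long getKey(Tuple3<Integer, Long, String> value) {
                return value.f1;
            }
        }, Order.ASCENDING).print();

        // Sort on several values: the KeySelector returns a tuple of those values,
        // because no further sort keys can be appended after a KeySelector key.
        input.sortPartition(new KeySelector<Tuple3<Integer, Long, String>, Tuple2<Integer, Long>>() {
            @Override
            public Tuple2<Integer, Long> getKey(Tuple3<Integer, Long, String> value) {
                return new Tuple2<>(value.f0, value.f1);
            }
        }, Order.DESCENDING).print();
    }
}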
......@@ -26,9 +26,13 @@ import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.SortPartitionOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
/**
* This operator represents a DataSet with locally sorted partitions.
......@@ -38,27 +42,58 @@ import java.util.Arrays;
@Public
public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPartitionOperator<T>> {
private int[] sortKeyPositions;
private List<Keys<T>> keys;
private Order[] sortOrders;
private List<Order> orders;
private final String sortLocationName;
private boolean useKeySelector;
public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order sortOrder, String sortLocationName) {
private SortPartitionOperator(DataSet<T> dataSet, String sortLocationName) {
super(dataSet, dataSet.getType());
keys = new ArrayList<>();
orders = new ArrayList<>();
this.sortLocationName = sortLocationName;
}
public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order sortOrder, String sortLocationName) {
this(dataSet, sortLocationName);
this.useKeySelector = false;
ensureSortableKey(sortField);
int[] flatOrderKeys = getFlatFields(sortField);
this.appendSorting(flatOrderKeys, sortOrder);
keys.add(new Keys.ExpressionKeys<>(sortField, getType()));
orders.add(sortOrder);
}
public SortPartitionOperator(DataSet<T> dataSet, String sortField, Order sortOrder, String sortLocationName) {
super(dataSet, dataSet.getType());
this.sortLocationName = sortLocationName;
this(dataSet, sortLocationName);
this.useKeySelector = false;
ensureSortableKey(sortField);
keys.add(new Keys.ExpressionKeys<>(sortField, getType()));
orders.add(sortOrder);
}
public <K> SortPartitionOperator(DataSet<T> dataSet, Keys.SelectorFunctionKeys<T, K> sortKey, Order sortOrder, String sortLocationName) {
this(dataSet, sortLocationName);
this.useKeySelector = true;
ensureSortableKey(sortKey);
int[] flatOrderKeys = getFlatFields(sortField);
this.appendSorting(flatOrderKeys, sortOrder);
keys.add(sortKey);
orders.add(sortOrder);
}
/**
* Returns whether a KeySelector is used for the sorting or not.
*/
public boolean useKeySelector() {
return useKeySelector;
}
/**
......@@ -70,9 +105,14 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart
* @return The DataSet with sorted local partitions.
*/
public SortPartitionOperator<T> sortPartition(int field, Order order) {
if (useKeySelector) {
throw new InvalidProgramException("Expression keys cannot be appended after a KeySelector");
}
ensureSortableKey(field);
keys.add(new Keys.ExpressionKeys<>(field, getType()));
orders.add(order);
int[] flatOrderKeys = getFlatFields(field);
this.appendSorting(flatOrderKeys, order);
return this;
}
......@@ -81,58 +121,41 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart
* local partition sorting of the DataSet.
*
* @param field The field expression referring to the field of the additional sort order of
* the local partition sorting.
* @param order The order of the additional sort order of the local partition sorting.
* @return The DataSet with sorted local partitions.
*/
public SortPartitionOperator<T> sortPartition(String field, Order order) {
int[] flatOrderKeys = getFlatFields(field);
this.appendSorting(flatOrderKeys, order);
if (useKeySelector) {
throw new InvalidProgramException("Expression keys cannot be appended after a KeySelector");
}
ensureSortableKey(field);
keys.add(new Keys.ExpressionKeys<>(field, getType()));
orders.add(order);
return this;
}
// --------------------------------------------------------------------------------------------
// Key Extraction
// --------------------------------------------------------------------------------------------
private int[] getFlatFields(int field) {
public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> keyExtractor, Order order) {
throw new InvalidProgramException("KeySelector cannot be chained.");
}
if (!Keys.ExpressionKeys.isSortKey(field, super.getType())) {
private void ensureSortableKey(int field) throws InvalidProgramException {
if (!Keys.ExpressionKeys.isSortKey(field, getType())) {
throw new InvalidProgramException("Selected sort key is not a sortable type");
}
Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(field, super.getType());
return ek.computeLogicalKeyPositions();
}
private int[] getFlatFields(String fields) {
if (!Keys.ExpressionKeys.isSortKey(fields, super.getType())) {
private void ensureSortableKey(String field) throws InvalidProgramException {
if (!Keys.ExpressionKeys.isSortKey(field, getType())) {
throw new InvalidProgramException("Selected sort key is not a sortable type");
}
Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(fields, super.getType());
return ek.computeLogicalKeyPositions();
}
private void appendSorting(int[] flatOrderFields, Order order) {
if(this.sortKeyPositions == null) {
// set sorting info
this.sortKeyPositions = flatOrderFields;
this.sortOrders = new Order[flatOrderFields.length];
Arrays.fill(this.sortOrders, order);
} else {
// append sorting info to existing info
int oldLength = this.sortKeyPositions.length;
int newLength = oldLength + flatOrderFields.length;
this.sortKeyPositions = Arrays.copyOf(this.sortKeyPositions, newLength);
this.sortOrders = Arrays.copyOf(this.sortOrders, newLength);
for(int i=0; i<flatOrderFields.length; i++) {
this.sortKeyPositions[oldLength+i] = flatOrderFields[i];
this.sortOrders[oldLength+i] = order;
}
private <K> void ensureSortableKey(Keys.SelectorFunctionKeys<T, K> sortKey) {
if (!sortKey.getKeyType().isSortKeyType()) {
throw new InvalidProgramException("Selected sort key is not a sortable type");
}
}
......@@ -144,16 +167,33 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart
String name = "Sort at " + sortLocationName;
if (useKeySelector) {
return translateToDataFlowWithKeyExtractor(input, (Keys.SelectorFunctionKeys<T, ?>) keys.get(0), orders.get(0), name);
}
// flatten sort key positions
List<Integer> allKeyPositions = new ArrayList<>();
List<Order> allOrders = new ArrayList<>();
for (int i = 0, length = keys.size(); i < length; i++) {
int[] sortKeyPositions = keys.get(i).computeLogicalKeyPositions();
Order order = orders.get(i);
for (int sortKeyPosition : sortKeyPositions) {
allKeyPositions.add(sortKeyPosition);
allOrders.add(order);
}
}
Ordering partitionOrdering = new Ordering();
for (int i = 0; i < this.sortKeyPositions.length; i++) {
partitionOrdering.appendOrdering(this.sortKeyPositions[i], null, this.sortOrders[i]);
for (int i = 0, length = allKeyPositions.size(); i < length; i++) {
partitionOrdering.appendOrdering(allKeyPositions.get(i), null, allOrders.get(i));
}
// distinguish between partition types
UnaryOperatorInformation<T, T> operatorInfo = new UnaryOperatorInformation<>(getType(), getType());
SortPartitionOperatorBase<T> noop = new SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
noop.setInput(input);
if (this.getParallelism() < 0) {
// use parallelism of input if not explicitly specified
noop.setParallelism(input.getParallelism());
} else {
......@@ -165,4 +205,32 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart
}
private <K> org.apache.flink.api.common.operators.SingleInputOperator<?, T, ?> translateToDataFlowWithKeyExtractor(
Operator<T> input, Keys.SelectorFunctionKeys<T, K> keys, Order order, String name) {
TypeInformation<Tuple2<K, T>> typeInfoWithKey = KeyFunctions.createTypeWithKey(keys);
Keys.ExpressionKeys<Tuple2<K, T>> newKey = new Keys.ExpressionKeys<>(0, typeInfoWithKey);
Operator<Tuple2<K, T>> keyedInput = KeyFunctions.appendKeyExtractor(input, keys);
int[] sortKeyPositions = newKey.computeLogicalKeyPositions();
Ordering partitionOrdering = new Ordering();
for (int keyPosition : sortKeyPositions) {
partitionOrdering.appendOrdering(keyPosition, null, order);
}
// distinguish between partition types
UnaryOperatorInformation<Tuple2<K, T>, Tuple2<K, T>> operatorInfo = new UnaryOperatorInformation<>(typeInfoWithKey, typeInfoWithKey);
SortPartitionOperatorBase<Tuple2<K, T>> noop = new SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
noop.setInput(keyedInput);
if (this.getParallelism() < 0) {
// use parallelism of input if not explicitly specified
noop.setParallelism(input.getParallelism());
} else {
// use explicitly specified parallelism
noop.setParallelism(this.getParallelism());
}
return KeyFunctions.appendKeyRemover(noop, keys);
}
}
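The KeySelector variant cannot be expressed as flat field positions on T, so translateToDataFlowWithKeyExtractor wraps every element into a Tuple2<K, T> via KeyFunctions.appendKeyExtractor, sorts the wrapped records on position 0, and strips the key again with KeyFunctions.appendKeyRemover. The plain-Java sketch below (my own illustration on an in-memory list, not Flink runtime code) shows the same wrap, sort, unwrap sequence:

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Illustration of the wrap -> sort on the key field -> unwrap idea used above.
public final class KeyedPartitionSortSketch {

    public static <T, K extends Comparable<K>> List<T> sortPartition(
            List<T> partition, Function<T, K> keySelector) {
        return partition.stream()
                // "appendKeyExtractor": pair every element with its extracted key
                .map(element -> Map.entry(keySelector.apply(element), element))
                // sort on the key half of the pair only (ascending, for brevity)
                .sorted(Map.Entry.comparingByKey())
                // "appendKeyRemover": drop the key, keep the original element
                .map(Map.Entry::getValue)
                .collect(Collectors.toList());
    }
}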
......@@ -169,6 +169,88 @@ public class SortPartitionTest {
tupleDs.sortPartition("f3", Order.ASCENDING);
}
@Test
public void testSortPartitionWithKeySelector1() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
// should work
try {
tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, Integer>() {
@Override
public Integer getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
return value.f0;
}
}, Order.ASCENDING);
} catch (Exception e) {
Assert.fail();
}
}
@Test(expected = InvalidProgramException.class)
public void testSortPartitionWithKeySelector2() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
// must not work
tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, Long[]>() {
@Override
public Long[] getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
return value.f3;
}
}, Order.ASCENDING);
}
@Test(expected = InvalidProgramException.class)
public void testSortPartitionWithKeySelector3() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
// must not work
tupleDs
.sortPartition("f1", Order.ASCENDING)
.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, CustomType>() {
@Override
public CustomType getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
return value.f2;
}
}, Order.ASCENDING);
}
@Test
public void testSortPartitionWithKeySelector4() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
// should work
try {
tupleDs.sortPartition(new KeySelector<Tuple4<Integer,Long,CustomType,Long[]>, Tuple2<Integer, Long>>() {
@Override
public Tuple2<Integer, Long> getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
return new Tuple2<>(value.f0, value.f1);
}
}, Order.ASCENDING);
} catch (Exception e) {
Assert.fail();
}
}
@Test(expected = InvalidProgramException.class)
public void testSortPartitionWithKeySelector5() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
// must not work
tupleDs
.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, CustomType>() {
@Override
public CustomType getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
return value.f2;
}
}, Order.ASCENDING)
.sortPartition("f1", Order.ASCENDING);
}
public static class CustomType implements Serializable {
public static class Nest {
......
......@@ -1511,6 +1511,31 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
new SortPartitionOperator[T](javaSet, field, order, getCallLocationName()))
}
/**
* Locally sorts the partitions of the DataSet on the extracted key in the specified order.
* The DataSet can be sorted on multiple values by returning a tuple from the KeySelector.
*
* Note that no additional sort keys can be appended to a KeySelector sort key. To sort
* the partitions by multiple values using a KeySelector, the KeySelector must return a tuple
* consisting of those values.
*/
def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T] = {
val keyExtractor = new KeySelector[T, K] {
val cleanFun = clean(fun)
def getKey(in: T) = cleanFun(in)
}
val keyType = implicitly[TypeInformation[K]]
new PartitionSortedDataSet[T](
new SortPartitionOperator[T](javaSet,
new Keys.SelectorFunctionKeys[T, K](
keyExtractor,
javaSet.getType,
keyType),
order,
getCallLocationName()))
}
// --------------------------------------------------------------------------------------------
// Result writing
// --------------------------------------------------------------------------------------------
......
......@@ -18,7 +18,9 @@
package org.apache.flink.api.scala
import org.apache.flink.annotation.Public
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.SortPartitionOperator
import scala.reflect.ClassTag
......@@ -37,16 +39,30 @@ class PartitionSortedDataSet[T: ClassTag](set: SortPartitionOperator[T])
* Appends the given field and order to the sort-partition operator.
*/
override def sortPartition(field: Int, order: Order): DataSet[T] = {
if (set.useKeySelector()) {
throw new InvalidProgramException("Expression keys cannot be appended after selector " +
"function keys")
}
this.set.sortPartition(field, order)
this
}
/**
* Appends the given field and order to the sort-partition operator.
*/
override def sortPartition(field: String, order: Order): DataSet[T] = {
if (set.useKeySelector()) {
throw new InvalidProgramException("Expression keys cannot be appended after selector " +
"function keys")
}
this.set.sortPartition(field, order)
this
}
override def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T] = {
throw new InvalidProgramException("KeySelector cannot be chained.")
}
}
......@@ -23,6 +23,7 @@ import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
......@@ -197,6 +198,58 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
compareResultAsText(result, expected);
}
@Test
public void testSortPartitionWithKeySelector1() throws Exception {
/*
* Test sort partition on an extracted key
*/
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
List<Tuple1<Boolean>> result = ds
.map(new IdMapper<Tuple3<Integer, Long, String>>()).setParallelism(4) // parallelize input
.sortPartition(new KeySelector<Tuple3<Integer, Long, String>, Long>() {
@Override
public Long getKey(Tuple3<Integer, Long, String> value) throws Exception {
return value.f1;
}
}, Order.ASCENDING)
.mapPartition(new OrderCheckMapper<>(new Tuple3AscendingChecker()))
.distinct().collect();
String expected = "(true)\n";
compareResultAsText(result, expected);
}
@Test
public void testSortPartitionWithKeySelector2() throws Exception {
/*
* Test sort partition on an extracted key
*/
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
List<Tuple1<Boolean>> result = ds
.map(new IdMapper<Tuple3<Integer, Long, String>>()).setParallelism(4) // parallelize input
.sortPartition(new KeySelector<Tuple3<Integer, Long, String>, Tuple2<Integer, Long>>() {
@Override
public Tuple2<Integer, Long> getKey(Tuple3<Integer, Long, String> value) throws Exception {
return new Tuple2<>(value.f0, value.f1);
}
}, Order.DESCENDING)
.mapPartition(new OrderCheckMapper<>(new Tuple3Checker()))
.distinct().collect();
String expected = "(true)\n";
compareResultAsText(result, expected);
}
public interface OrderChecker<T> extends Serializable {
boolean inOrder(T t1, T t2);
}
......@@ -209,6 +262,14 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
}
}
@SuppressWarnings("serial")
public static class Tuple3AscendingChecker implements OrderChecker<Tuple3<Integer, Long, String>> {
@Override
public boolean inOrder(Tuple3<Integer, Long, String> t1, Tuple3<Integer, Long, String> t2) {
return t1.f1 <= t2.f1;
}
}
@SuppressWarnings("serial")
public static class Tuple5Checker implements OrderChecker<Tuple5<Integer, Long, Integer, String, Long>> {
@Override
......
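In these ITCases, every partition is reduced to booleans by an OrderCheckMapper, and distinct() then collapses them so that a fully ordered result prints exactly "(true)". The mapper itself is outside this hunk; the following is only a rough reconstruction of the idea (class name and field are my own), built on the OrderChecker interface shown above:

import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.util.Collector;

// Sketch of a per-partition order check: emit a single boolean that is true only if
// every adjacent pair of records in the partition satisfies the given OrderChecker.
public class OrderCheckMapperSketch<T> implements MapPartitionFunction<T, Tuple1<Boolean>> {

    private final OrderChecker<T> checker;

    public OrderCheckMapperSketch(OrderChecker<T> checker) {
        this.checker = checker;
    }

    @Override
    public void mapPartition(Iterable<T> values, Collector<Tuple1<Boolean>> out) {
        T previous = null;
        boolean inOrder = true;
        for (T value : values) {
            if (previous != null && !checker.inOrder(previous, value)) {
                inOrder = false;
                break;
            }
            previous = value;
        }
        out.collect(new Tuple1<>(inOrder));
    }
}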
......@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.flink.api.common.functions.MapPartitionFunction
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
......@@ -166,6 +167,58 @@ class SortPartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestB
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
def testSortPartitionWithKeySelector1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(4)
val ds = CollectionDataSets.get3TupleDataSet(env)
val result = ds
.map { x => x }.setParallelism(4)
.sortPartition(_._2, Order.ASCENDING)
.mapPartition(new OrderCheckMapper(new Tuple3AscendingChecker))
.distinct()
.collect()
val expected: String = "(true)\n"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test
def testSortPartitionWithKeySelector2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(4)
val ds = CollectionDataSets.get3TupleDataSet(env)
val result = ds
.map { x => x }.setParallelism(4)
.sortPartition(x => (x._2, x._1), Order.DESCENDING)
.mapPartition(new OrderCheckMapper(new Tuple3Checker))
.distinct()
.collect()
val expected: String = "(true)\n"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
@Test(expected = classOf[InvalidProgramException])
def testSortPartitionWithKeySelector3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(4)
val ds = CollectionDataSets.get3TupleDataSet(env)
val result = ds
.map { x => x }.setParallelism(4)
.sortPartition(x => (x._2, x._1), Order.DESCENDING)
.sortPartition(0, Order.DESCENDING)
.mapPartition(new OrderCheckMapper(new Tuple3Checker))
.distinct()
.collect()
val expected: String = "(true)\n"
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
}
trait OrderChecker[T] extends Serializable {
......@@ -178,6 +231,12 @@ class Tuple3Checker extends OrderChecker[(Int, Long, String)] {
}
}
class Tuple3AscendingChecker extends OrderChecker[(Int, Long, String)] {
def inOrder(t1: (Int, Long, String), t2: (Int, Long, String)): Boolean = {
t1._2 <= t2._2
}
}
class Tuple5Checker extends OrderChecker[(Int, Long, Int, String, Long)] {
def inOrder(t1: (Int, Long, Int, String, Long), t2: (Int, Long, Int, String, Long)): Boolean = {
t1._5 < t2._5 || t1._5 == t2._5 && t1._3 >= t2._3
......