Commit f83db149 authored by mingliang, committed by Fabian Hueske

[FLINK-1112] Add KeySelector group sorting on KeySelector grouping

This closes #209
Parent ba7a19c1
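
For orientation, a minimal usage sketch of the API this commit enables (hypothetical data; the KeySelector-based sortGroup on a KeySelector grouping is the new piece):

import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.util.Collector;

public class KeySelectorGroupSortExample {

    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // Hypothetical input; any DataSet<Tuple3<Integer, Long, String>> would do.
        DataSet<Tuple3<Integer, Long, String>> input = env.fromElements(
                new Tuple3<Integer, Long, String>(1, 1L, "b"),
                new Tuple3<Integer, Long, String>(2, 1L, "a"),
                new Tuple3<Integer, Long, String>(3, 2L, "c"));

        input
            // group on a KeySelector ...
            .groupBy(new KeySelector<Tuple3<Integer, Long, String>, Long>() {
                public Long getKey(Tuple3<Integer, Long, String> t) { return t.f1; }
            })
            // ... and sort each group on another KeySelector (new with this commit)
            .sortGroup(new KeySelector<Tuple3<Integer, Long, String>, String>() {
                public String getKey(Tuple3<Integer, Long, String> t) { return t.f2; }
            }, Order.DESCENDING)
            // the reduce function sees each group's values in descending key order
            .reduceGroup(new GroupReduceFunction<Tuple3<Integer, Long, String>, String>() {
                public void reduce(Iterable<Tuple3<Integer, Long, String>> values, Collector<String> out) {
                    StringBuilder concat = new StringBuilder();
                    for (Tuple3<Integer, Long, String> t : values) {
                        concat.append(t.f2).append('-');
                    }
                    out.collect(concat.toString());
                }
            })
            .print();

        // On the DataSet API of this era, print() registers a sink; execute() runs the job.
        env.execute();
    }
}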
......@@ -31,7 +31,10 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.java.operators.translation.KeyExtractingMapper;
import org.apache.flink.api.java.operators.translation.PlanUnwrappingReduceGroupOperator;
import org.apache.flink.api.java.operators.translation.PlanUnwrappingSortedReduceGroupOperator;
import org.apache.flink.api.java.operators.translation.TwoKeyExtractingMapper;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.DataSet;
......@@ -144,14 +147,35 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
@SuppressWarnings("unchecked")
Keys.SelectorFunctionKeys<IN, ?> selectorKeys = (Keys.SelectorFunctionKeys<IN, ?>) grouper.getKeys();
PlanUnwrappingReduceGroupOperator<IN, OUT, ?> po = translateSelectorFunctionReducer(
if (grouper instanceof SortedGrouping) {
SortedGrouping<IN> sortedGrouper = (SortedGrouping<IN>) grouper;
Keys.SelectorFunctionKeys<IN, ?> sortKeys = sortedGrouper.getSortSelectionFunctionKey();
PlanUnwrappingSortedReduceGroupOperator<IN, OUT, ?, ?> po = translateSelectorFunctionSortedReducer(
selectorKeys, sortKeys, function, getInputType(), getResultType(), name, input, isCombinable());
// set group order
int[] sortKeyPositions = sortedGrouper.getGroupSortKeyPositions();
Order[] sortOrders = sortedGrouper.getGroupSortOrders();
Ordering o = new Ordering();
for(int i=0; i < sortKeyPositions.length; i++) {
o.appendOrdering(sortKeyPositions[i], null, sortOrders[i]);
}
po.setGroupOrder(o);
po.setDegreeOfParallelism(this.getParallelism());
po.setCustomPartitioner(grouper.getCustomPartitioner());
return po;
} else {
PlanUnwrappingReduceGroupOperator<IN, OUT, ?> po = translateSelectorFunctionReducer(
selectorKeys, function, getInputType(), getResultType(), name, input, isCombinable());
po.setDegreeOfParallelism(getParallelism());
po.setCustomPartitioner(grouper.getCustomPartitioner());
return po;
po.setDegreeOfParallelism(this.getParallelism());
po.setCustomPartitioner(grouper.getCustomPartitioner());
return po;
}
}
else if (grouper.getKeys() instanceof Keys.ExpressionKeys) {
......@@ -213,4 +237,32 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
return reducer;
}
private static <IN, OUT, K1, K2> PlanUnwrappingSortedReduceGroupOperator<IN, OUT, K1, K2> translateSelectorFunctionSortedReducer(
Keys.SelectorFunctionKeys<IN, ?> rawGroupingKey, Keys.SelectorFunctionKeys<IN, ?> rawSortingKey, GroupReduceFunction<IN, OUT> function,
TypeInformation<IN> inputType, TypeInformation<OUT> outputType, String name, Operator<IN> input,
boolean combinable)
{
@SuppressWarnings("unchecked")
final Keys.SelectorFunctionKeys<IN, K1> groupingKey = (Keys.SelectorFunctionKeys<IN, K1>) rawGroupingKey;
@SuppressWarnings("unchecked")
final Keys.SelectorFunctionKeys<IN, K2> sortingKey = (Keys.SelectorFunctionKeys<IN, K2>) rawSortingKey;
TypeInformation<Tuple3<K1, K2, IN>> typeInfoWithKey = new TupleTypeInfo<Tuple3<K1, K2, IN>>(groupingKey.getKeyType(), sortingKey.getKeyType(), inputType);
TwoKeyExtractingMapper<IN, K1, K2> extractor = new TwoKeyExtractingMapper<IN, K1, K2>(groupingKey.getKeyExtractor(), sortingKey.getKeyExtractor());
PlanUnwrappingSortedReduceGroupOperator<IN, OUT, K1, K2> reducer = new PlanUnwrappingSortedReduceGroupOperator<IN, OUT, K1, K2>(function, groupingKey, sortingKey, name, outputType, typeInfoWithKey, combinable);
MapOperatorBase<IN, Tuple3<K1, K2, IN>, MapFunction<IN, Tuple3<K1, K2, IN>>> mapper = new MapOperatorBase<IN, Tuple3<K1, K2, IN>, MapFunction<IN, Tuple3<K1, K2, IN>>>(extractor, new UnaryOperatorInformation<IN, Tuple3<K1, K2, IN>>(inputType, typeInfoWithKey), "Key Extractor");
reducer.setInput(mapper);
mapper.setInput(input);
// set the mapper's parallelism to the input parallelism to make sure it is chained
mapper.setDegreeOfParallelism(input.getDegreeOfParallelism());
return reducer;
}
}
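
The translation above rewrites the plan as input -> key-extracting map -> group reduce over (groupKey, sortKey, value) triples -> unwrap. As a rough illustration of that semantics only, a plain-Java sketch with hypothetical extractor functions standing in for the two KeySelectors (not Flink's internal classes):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

class KeySelectorGroupSortSemantics {

    // Group the input on groupKey and sort each group on sortKey; this is the effect the
    // wrapped (groupKey, sortKey, value) triples achieve inside the translated plan.
    static <T, K1, K2 extends Comparable<K2>> Map<K1, List<T>> groupAndSort(
            List<T> input, Function<T, K1> groupKey, Function<T, K2> sortKey) {
        Map<K1, List<T>> groups = new HashMap<>();
        for (T value : input) {
            K1 key = groupKey.apply(value);
            List<T> group = groups.get(key);
            if (group == null) {
                group = new ArrayList<>();
                groups.put(key, group);
            }
            group.add(value);
        }
        for (List<T> group : groups.values()) {
            group.sort(Comparator.comparing(sortKey));
        }
        return groups;
    }

    public static void main(String[] args) {
        Map<Integer, List<String>> byLength = groupAndSort(
                Arrays.asList("pear", "fig", "plum", "kiwi"),
                s -> s.length(),   // plays the role of the grouping KeySelector
                s -> s);           // plays the role of the sorting KeySelector
        System.out.println(byLength); // groups: 3 -> [fig], 4 -> [kiwi, pear, plum]
    }
}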
......@@ -47,7 +47,8 @@ import com.google.common.base.Preconditions;
public class SortedGrouping<T> extends Grouping<T> {
private int[] groupSortKeyPositions;
private Order[] groupSortOrders ;
private Order[] groupSortOrders;
private Keys.SelectorFunctionKeys<T, ?> groupSortSelectorFunctionKey = null;
/*
* int sorting keys for tuples
......@@ -83,6 +84,26 @@ public class SortedGrouping<T> extends Grouping<T> {
this.groupSortOrders = new Order[groupSortKeyPositions.length];
Arrays.fill(this.groupSortOrders, order); // if field == "*"
}
/*
* KeySelector sorting for any data type
*/
public <K> SortedGrouping(DataSet<T> set, Keys<T> keys, Keys.SelectorFunctionKeys<T, K> keySelector, Order order) {
super(set, keys);
if (!(this.keys instanceof Keys.SelectorFunctionKeys)) {
throw new InvalidProgramException("Sorting on KeySelector only works for KeySelector grouping.");
}
this.groupSortKeyPositions = keySelector.computeLogicalKeyPositions();
for (int i = 0; i < groupSortKeyPositions.length; i++) {
groupSortKeyPositions[i] += this.keys.getNumberOfKeyFields();
}
this.groupSortSelectorFunctionKey = keySelector;
this.groupSortOrders = new Order[groupSortKeyPositions.length];
Arrays.fill(this.groupSortOrders, order);
}
// --------------------------------------------------------------------------------------------
......@@ -109,6 +130,10 @@ public class SortedGrouping<T> extends Grouping<T> {
return this;
}
protected Keys.SelectorFunctionKeys<T, ?> getSortSelectionFunctionKey() {
return this.groupSortSelectorFunctionKey;
}
/**
* Applies a GroupReduce transformation on a grouped and sorted {@link DataSet}.<br/>
* The transformation calls a {@link org.apache.flink.api.common.functions.RichGroupReduceFunction} for each group of the DataSet.
......@@ -162,7 +187,9 @@ public class SortedGrouping<T> extends Grouping<T> {
* @see Order
*/
public SortedGrouping<T> sortGroup(int field, Order order) {
if (groupSortSelectorFunctionKey != null) {
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not supported");
}
if (!dataSet.getType().isTupleType()) {
throw new InvalidProgramException("Specifying order keys via field positions is only valid for tuple data types");
}
......@@ -202,7 +229,9 @@ public class SortedGrouping<T> extends Grouping<T> {
* @see Order
*/
public SortedGrouping<T> sortGroup(String field, Order order) {
if (groupSortSelectorFunctionKey != null) {
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not supported");
}
if (! (dataSet.getType() instanceof CompositeType)) {
throw new InvalidProgramException("Specifying order keys via field positions is only valid for composite data types (pojo / tuple / case class)");
}
......@@ -210,5 +239,5 @@ public class SortedGrouping<T> extends Grouping<T> {
addSortGroupInternal(ek, order);
return this;
}
}
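
The guards above make two misuses fail fast: sorting by a KeySelector on a grouping that was not made with a KeySelector, and chaining a positional sortGroup after a KeySelector sort. A sketch of both rejected programs (hypothetical data):

import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;

public class KeySelectorSortRestrictions {

    public static void main(String[] args) {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Tuple2<Integer, String>> ds = env.fromElements(
                new Tuple2<Integer, String>(1, "a"),
                new Tuple2<Integer, String>(1, "b"));

        KeySelector<Tuple2<Integer, String>, String> sortKey =
                new KeySelector<Tuple2<Integer, String>, String>() {
                    public String getKey(Tuple2<Integer, String> t) { return t.f1; }
                };

        try {
            // KeySelector sorting is only allowed on a KeySelector grouping ...
            ds.groupBy(0).sortGroup(sortKey, Order.ASCENDING);
        } catch (InvalidProgramException expected) {
            // "Sorting on KeySelector only works for KeySelector grouping."
        }

        try {
            // ... and further sortGroup calls cannot be chained after a KeySelector sort.
            ds.groupBy(new KeySelector<Tuple2<Integer, String>, Integer>() {
                    public Integer getKey(Tuple2<Integer, String> t) { return t.f0; }
                })
                .sortGroup(sortKey, Order.ASCENDING)
                .sortGroup(0, Order.DESCENDING);
        } catch (InvalidProgramException expected) {
            // "Chaining sortGroup with KeySelector sorting is not supported"
        }
    }
}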
......@@ -28,6 +28,7 @@ import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.Utils;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.functions.FirstReducer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.functions.SelectByMaxFunction;
import org.apache.flink.api.java.functions.SelectByMinFunction;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
......@@ -251,5 +252,22 @@ public class UnsortedGrouping<T> extends Grouping<T> {
sg.customPartitioner = getCustomPartitioner();
return sg;
}
/**
* Sorts elements within a group on a key extracted by the specified {@link org.apache.flink.api.java.functions.KeySelector}
* in the specified {@link Order}.<br/>
* <b>Note: Group sorting via a KeySelector works for any input data type.</b><br/>
* Chaining {@link #sortGroup(KeySelector, Order)} calls is not supported.
*
* @param keySelector The KeySelector with which the group is sorted.
* @param order The Order in which the extracted key is sorted.
* @return A SortedGrouping with the specified group order.
*
* @see Order
*/
public <K> SortedGrouping<T> sortGroup(KeySelector<T, K> keySelector, Order order) {
TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keySelector, this.dataSet.getType());
return new SortedGrouping<T>(this.dataSet, this.keys, new Keys.SelectorFunctionKeys<T, K>(keySelector, this.dataSet.getType(), keyType), order);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.operators.translation;
import org.apache.flink.api.common.functions.FlatCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.java.operators.Keys;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.util.Collector;
/**
* A reduce operator that takes 3-tuples (groupKey, sortKey, value), and applies the sorted group reduce
* operation only on the unwrapped values.
*/
public class PlanUnwrappingSortedReduceGroupOperator<IN, OUT, K1, K2> extends GroupReduceOperatorBase<Tuple3<K1, K2, IN>, OUT, GroupReduceFunction<Tuple3<K1, K2, IN>,OUT>> {
public PlanUnwrappingSortedReduceGroupOperator(GroupReduceFunction<IN, OUT> udf, Keys.SelectorFunctionKeys<IN, K1> groupingKey, Keys.SelectorFunctionKeys<IN, K2> sortingKey, String name,
TypeInformation<OUT> outType, TypeInformation<Tuple3<K1, K2, IN>> typeInfoWithKey, boolean combinable)
{
super(combinable ? new TupleUnwrappingFlatCombinableGroupReducer<IN, OUT, K1, K2>((RichGroupReduceFunction<IN, OUT>) udf) : new TupleUnwrappingNonCombinableGroupReducer<IN, OUT, K1, K2>(udf),
new UnaryOperatorInformation<Tuple3<K1, K2, IN>, OUT>(typeInfoWithKey, outType), groupingKey.computeLogicalKeyPositions(), name);
super.setCombinable(combinable);
}
// --------------------------------------------------------------------------------------------
@RichGroupReduceFunction.Combinable
public static final class TupleUnwrappingFlatCombinableGroupReducer<IN, OUT, K1, K2> extends WrappingFunction<RichGroupReduceFunction<IN, OUT>>
implements GroupReduceFunction<Tuple3<K1, K2, IN>, OUT>, FlatCombineFunction<Tuple3<K1, K2, IN>>
{
private static final long serialVersionUID = 1L;
private Tuple3UnwrappingIterator<IN, K1, K2> iter;
private Tuple3WrappingCollector<IN, K1, K2> coll;
private TupleUnwrappingFlatCombinableGroupReducer(RichGroupReduceFunction<IN, OUT> wrapped) {
super(wrapped);
this.iter = new Tuple3UnwrappingIterator<IN, K1, K2>();
this.coll = new Tuple3WrappingCollector<IN, K1, K2>(this.iter);
}
@Override
public void reduce(Iterable<Tuple3<K1, K2, IN>> values, Collector<OUT> out) throws Exception {
iter.set(values.iterator());
this.wrappedFunction.reduce(iter, out);
}
@Override
public void combine(Iterable<Tuple3<K1, K2, IN>> values, Collector<Tuple3<K1, K2, IN>> out) throws Exception {
iter.set(values.iterator());
coll.set(out);
this.wrappedFunction.combine(iter, coll);
}
@Override
public String toString() {
return this.wrappedFunction.toString();
}
}
public static final class TupleUnwrappingNonCombinableGroupReducer<IN, OUT, K1, K2> extends WrappingFunction<GroupReduceFunction<IN, OUT>>
implements GroupReduceFunction<Tuple3<K1, K2, IN>, OUT>
{
private static final long serialVersionUID = 1L;
private final Tuple3UnwrappingIterator<IN, K1, K2> iter;
private TupleUnwrappingNonCombinableGroupReducer(GroupReduceFunction<IN, OUT> wrapped) {
super(wrapped);
this.iter = new Tuple3UnwrappingIterator<IN, K1, K2>();
}
@Override
public void reduce(Iterable<Tuple3<K1, K2, IN>> values, Collector<OUT> out) throws Exception {
iter.set(values.iterator());
this.wrappedFunction.reduce(iter, out);
}
@Override
public String toString() {
return this.wrappedFunction.toString();
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.operators.translation;
import java.util.Iterator;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.util.TraversableOnceException;
/**
* An iterator that reads 3-tuples (groupKey, sortKey, value) and returns only the values (third field).
* The iterator also tracks the group keys as the triples flow through it.
*/
public class Tuple3UnwrappingIterator<T, K1, K2> implements Iterator<T>, Iterable<T>, java.io.Serializable {
private static final long serialVersionUID = 1L;
private K1 lastGroupKey;
private K2 lastSortKey;
private Iterator<Tuple3<K1, K2, T>> iterator;
private boolean iteratorAvailable;
public void set(Iterator<Tuple3<K1, K2, T>> iterator) {
this.iterator = iterator;
this.iteratorAvailable = true;
}
public K1 getLastGroupKey() {
return lastGroupKey;
}
public K2 getLastSortKey() {
return lastSortKey;
}
@Override
public boolean hasNext() {
return iterator.hasNext();
}
@Override
public T next() {
Tuple3<K1, K2, T> t = iterator.next();
this.lastGroupKey = t.f0;
this.lastSortKey = t.f1;
return t.f2;
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
@Override
public Iterator<T> iterator() {
if (iteratorAvailable) {
iteratorAvailable = false;
return this;
} else {
throw new TraversableOnceException();
}
}
}
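
A small sketch of the iterator in isolation (assuming the class above is on the classpath): user code iterates over the unwrapped values, while the iterator retains the keys of the element it last returned and may be traversed only once.

import java.util.Arrays;
import org.apache.flink.api.java.operators.translation.Tuple3UnwrappingIterator;
import org.apache.flink.api.java.tuple.Tuple3;

public class UnwrappingIteratorSketch {

    public static void main(String[] args) {
        Tuple3UnwrappingIterator<String, Long, Integer> it =
                new Tuple3UnwrappingIterator<String, Long, Integer>();
        it.set(Arrays.asList(
                new Tuple3<Long, Integer, String>(1L, 2, "a"),
                new Tuple3<Long, Integer, String>(1L, 1, "b")).iterator());

        // Iterable, but traversable once: a second for-each would throw TraversableOnceException.
        for (String value : it) {
            // Only the third field is returned; the keys of that element are remembered.
            System.out.println(value + " group=" + it.getLastGroupKey() + " sort=" + it.getLastSortKey());
        }
    }
}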
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.operators.translation;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.util.Collector;
/**
* Wraps values emitted by the combine method of a group reduce with key selector sorting back into Tuple3<groupKey, sortKey, value> records.
*/
public class Tuple3WrappingCollector<IN, K1, K2> implements Collector<IN>, java.io.Serializable {
private static final long serialVersionUID = 1L;
private final Tuple3UnwrappingIterator<IN, K1, K2> tui;
private final Tuple3<K1, K2, IN> outTuple;
private Collector<Tuple3<K1, K2, IN>> wrappedCollector;
public Tuple3WrappingCollector(Tuple3UnwrappingIterator<IN, K1, K2> tui) {
this.tui = tui;
this.outTuple = new Tuple3<K1, K2, IN>();
}
public void set(Collector<Tuple3<K1, K2, IN>> wrappedCollector) {
this.wrappedCollector = wrappedCollector;
}
@Override
public void close() {
this.wrappedCollector.close();
}
@Override
public void collect(IN record) {
this.outTuple.f0 = this.tui.getLastGroupKey();
this.outTuple.f1 = this.tui.getLastSortKey();
this.outTuple.f2 = record;
this.wrappedCollector.collect(outTuple);
}
}
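
To see how the iterator and this collector cooperate on the combine path (assuming both classes above): records emitted by the user's combiner are re-wrapped with the keys of the element the iterator returned last, so the keys survive the combine step.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.java.operators.translation.Tuple3UnwrappingIterator;
import org.apache.flink.api.java.operators.translation.Tuple3WrappingCollector;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.util.Collector;

public class WrappingCollectorSketch {

    public static void main(String[] args) {
        Tuple3UnwrappingIterator<String, Long, Integer> iter =
                new Tuple3UnwrappingIterator<String, Long, Integer>();
        Tuple3WrappingCollector<String, Long, Integer> coll =
                new Tuple3WrappingCollector<String, Long, Integer>(iter);

        final List<Tuple3<Long, Integer, String>> sink = new ArrayList<Tuple3<Long, Integer, String>>();
        coll.set(new Collector<Tuple3<Long, Integer, String>>() {
            public void collect(Tuple3<Long, Integer, String> record) {
                // copy, because the wrapping collector reuses its output tuple
                sink.add(new Tuple3<Long, Integer, String>(record.f0, record.f1, record.f2));
            }
            public void close() { }
        });

        iter.set(Arrays.asList(
                new Tuple3<Long, Integer, String>(7L, 3, "a"),
                new Tuple3<Long, Integer, String>(7L, 1, "b")).iterator());

        // A trivial "combiner": consume the unwrapped values and emit one combined value.
        StringBuilder combined = new StringBuilder();
        while (iter.hasNext()) {
            combined.append(iter.next());
        }
        coll.collect(combined.toString());

        // The collector re-attached the keys (7, 1) of the last element the iterator returned.
        System.out.println(sink); // [(7,1,ab)]
    }
}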
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.operators.translation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.tuple.Tuple3;
public final class TwoKeyExtractingMapper<T, K1, K2> extends RichMapFunction<T, Tuple3<K1, K2, T>> {
private static final long serialVersionUID = 1L;
private final KeySelector<T, K1> keySelector1;
private final KeySelector<T, K2> keySelector2;
private final Tuple3<K1, K2, T> tuple = new Tuple3<K1, K2, T>();
public TwoKeyExtractingMapper(KeySelector<T, K1> keySelector1, KeySelector<T, K2> keySelector2) {
this.keySelector1 = keySelector1;
this.keySelector2 = keySelector2;
}
@Override
public Tuple3<K1, K2, T> map(T value) throws Exception {
K1 key1 = keySelector1.getKey(value);
K2 key2 = keySelector2.getKey(value);
tuple.f0 = key1;
tuple.f1 = key2;
tuple.f2 = value;
return tuple;
}
}
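
A minimal sketch of the key-extracting mapper in isolation (hypothetical extractors): it emits (groupKey, sortKey, value) triples and reuses a single Tuple3 instance across calls.

import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.translation.TwoKeyExtractingMapper;
import org.apache.flink.api.java.tuple.Tuple3;

public class TwoKeyExtractorSketch {

    public static void main(String[] args) throws Exception {
        TwoKeyExtractingMapper<String, Integer, Character> mapper =
                new TwoKeyExtractingMapper<String, Integer, Character>(
                        new KeySelector<String, Integer>() {
                            public Integer getKey(String s) { return s.length(); }    // group key
                        },
                        new KeySelector<String, Character>() {
                            public Character getKey(String s) { return s.charAt(0); } // sort key
                        });

        // Produces (5, 'f', "flink"); note that the same Tuple3 object is reused on every call.
        Tuple3<Integer, Character, String> wrapped = mapper.map("flink");
        System.out.println(wrapped);
    }
}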
......@@ -18,7 +18,7 @@
package org.apache.flink.api.scala
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.java.functions.FirstReducer
import org.apache.flink.api.java.functions.{KeySelector, FirstReducer}
import org.apache.flink.api.scala.operators.ScalaAggregateOperator
import scala.collection.JavaConverters._
import org.apache.commons.lang3.Validate
......@@ -48,9 +48,11 @@ class GroupedDataSet[T: ClassTag](
// when using a group-at-a-time reduce function.
private val groupSortKeyPositions = mutable.MutableList[Either[Int, String]]()
private val groupSortOrders = mutable.MutableList[Order]()
private var partitioner : Partitioner[_] = _
private var groupSortKeySelector: Option[Keys.SelectorFunctionKeys[T, _]] = None
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
* use one of the group-at-a-time operations, i.e. `reduceGroup`.
......@@ -65,6 +67,10 @@ class GroupedDataSet[T: ClassTag](
if (field >= set.getType.getArity) {
throw new IllegalArgumentException("Order key out of tuple bounds.")
}
if (groupSortKeySelector.nonEmpty) {
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not " +
"supported.")
}
groupSortKeyPositions += Left(field)
groupSortOrders += order
this
......@@ -77,55 +83,88 @@ class GroupedDataSet[T: ClassTag](
* This only works on CaseClass DataSets.
*/
def sortGroup(field: String, order: Order): GroupedDataSet[T] = {
if (groupSortKeySelector.nonEmpty) {
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not" +
"supported.")
}
groupSortKeyPositions += Right(field)
groupSortOrders += order
this
}
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
* use one of the group-at-a-time operations, i.e. `reduceGroup`.
*
* This works on any data type.
*/
def sortGroup[K: TypeInformation](fun: T => K, order: Order): GroupedDataSet[T] = {
if (groupSortOrders.length != 0) {
throw new InvalidProgramException("Chaining sortGroup with KeySelector sorting is not" +
"supported.")
}
groupSortOrders += order
val keyType = implicitly[TypeInformation[K]]
val keyExtractor = new KeySelector[T, K] {
def getKey(in: T) = fun(in)
}
groupSortKeySelector = Some(new Keys.SelectorFunctionKeys[T, K](
keyExtractor,
set.javaSet.getType,
keyType))
this
}
/**
* Creates a [[SortedGrouping]] if group sorting keys were specified.
*/
private def maybeCreateSortedGrouping(): Grouping[T] = {
if (groupSortKeyPositions.length > 0) {
val grouping = groupSortKeyPositions(0) match {
case Left(pos) =>
new SortedGrouping[T](
set.javaSet,
keys,
pos,
groupSortOrders(0))
case Right(field) =>
new SortedGrouping[T](
set.javaSet,
keys,
field,
groupSortOrders(0))
groupSortKeySelector match {
case Some(keySelector) =>
new SortedGrouping[T](set.javaSet, keys, keySelector, groupSortOrders(0))
case None =>
if (groupSortKeyPositions.length > 0) {
val grouping = groupSortKeyPositions(0) match {
case Left(pos) =>
new SortedGrouping[T](
set.javaSet,
keys,
pos,
groupSortOrders(0))
}
// now manually add the rest of the keys
for (i <- 1 until groupSortKeyPositions.length) {
groupSortKeyPositions(i) match {
case Left(pos) =>
grouping.sortGroup(pos, groupSortOrders(i))
case Right(field) =>
new SortedGrouping[T](
set.javaSet,
keys,
field,
groupSortOrders(0))
case Right(field) =>
grouping.sortGroup(field, groupSortOrders(i))
}
// now manually add the rest of the keys
for (i <- 1 until groupSortKeyPositions.length) {
groupSortKeyPositions(i) match {
case Left(pos) =>
grouping.sortGroup(pos, groupSortOrders(i))
case Right(field) =>
grouping.sortGroup(field, groupSortOrders(i))
}
}
if (partitioner == null) {
grouping
} else {
grouping.withPartitioner(partitioner)
}
}
}
if (partitioner == null) {
grouping
} else {
grouping.withPartitioner(partitioner)
createUnsortedGrouping()
}
} else {
createUnsortedGrouping()
}
}
/** Convenience methods for creating the [[UnsortedGrouping]] */
private def createUnsortedGrouping(): Grouping[T] = {
val grp = new UnsortedGrouping[T](set.javaSet, keys)
......
......@@ -26,6 +26,7 @@ import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple5;
......@@ -699,6 +700,192 @@ public class GroupReduceITCase extends MultipleProgramsTestBase {
"2---(30,600)-(30,400)-(30,200)-(20,201)-(20,200)-\n";
}
@Test
public void testTupleKeySelectorGroupSort() throws Exception {
/*
* check correctness of sorted groupReduce on tuples with keyselector sorting
*/
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setDegreeOfParallelism(1);
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
DataSet<Tuple3<Integer, Long, String>> reduceDs = ds
.groupBy(new LongFieldExtractor<Tuple3<Integer, Long, String>>(1))
.sortGroup(new StringFieldExtractor<Tuple3<Integer, Long, String>>(2), Order.DESCENDING)
.reduceGroup(new Tuple3SortedGroupReduce());
reduceDs.writeAsCsv(resultPath);
env.execute();
// return expected result
expected = "1,1,Hi\n" +
"5,2,Hello world-Hello\n" +
"15,3,Luke Skywalker-I am fine.-Hello world, how are you?\n" +
"34,4,Comment#4-Comment#3-Comment#2-Comment#1\n" +
"65,5,Comment#9-Comment#8-Comment#7-Comment#6-Comment#5\n" +
"111,6,Comment#15-Comment#14-Comment#13-Comment#12-Comment#11-Comment#10\n";
}
public static class TwoTuplePojoExtractor implements KeySelector<CustomType, Tuple2<Integer, Integer>> {
private static final long serialVersionUID = 1L;
@Override
public Tuple2<Integer, Integer> getKey(CustomType value) throws Exception {
return new Tuple2<Integer, Integer>(value.myInt, value.myInt);
}
}
public static class StringPojoExtractor implements KeySelector<CustomType, String> {
private static final long serialVersionUID = 1L;
@Override
public String getKey(CustomType value) throws Exception {
return value.myString;
}
}
@Test
public void testPojoKeySelectorGroupSort() throws Exception {
/*
* check correctness of sorted groupReduce on custom type with keyselector sorting
*/
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<CustomType> ds = CollectionDataSets.getCustomTypeDataSet(env);
DataSet<CustomType> reduceDs = ds
.groupBy(new TwoTuplePojoExtractor())
.sortGroup(new StringPojoExtractor(), Order.DESCENDING)
.reduceGroup(new CustomTypeSortedGroupReduce());
reduceDs.writeAsText(resultPath);
env.execute();
// return expected result
expected = "1,0,Hi\n" +
"2,3,Hello world-Hello\n" +
"3,12,Luke Skywalker-I am fine.-Hello world, how are you?\n" +
"4,30,Comment#4-Comment#3-Comment#2-Comment#1\n" +
"5,60,Comment#9-Comment#8-Comment#7-Comment#6-Comment#5\n" +
"6,105,Comment#15-Comment#14-Comment#13-Comment#12-Comment#11-Comment#10\n";
}
public static class LongFieldExtractor<T extends Tuple> implements KeySelector<T, Long> {
private static final long serialVersionUID = 1L;
private int field;
public LongFieldExtractor() { }
public LongFieldExtractor(int field) {
this.field = field;
}
@Override
public Long getKey(T t) throws Exception {
return ((Tuple)t).getField(field);
}
}
public static class IntFieldExtractor<T extends Tuple> implements KeySelector<T, Integer> {
private static final long serialVersionUID = 1L;
private int field;
public IntFieldExtractor() { }
public IntFieldExtractor(int field) {
this.field = field;
}
@Override
public Integer getKey(T t) throws Exception {
return ((Tuple)t).getField(field);
}
}
public static class StringFieldExtractor<T extends Tuple> implements KeySelector<T, String> {
private static final long serialVersionUID = 1L;
private int field;
public StringFieldExtractor() { }
public StringFieldExtractor(int field) {
this.field = field;
}
@Override
public String getKey(T t) throws Exception {
return ((Tuple)t).getField(field);
}
}
@Test
public void testTupleKeySelectorSortWithCombine() throws Exception {
/*
* check correctness of sorted groupReduce with combine on tuples with keyselector sorting
*/
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setDegreeOfParallelism(1);
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
DataSet<Tuple2<Integer, String>> reduceDs = ds.
groupBy(new LongFieldExtractor<Tuple3<Integer, Long, String>>(1))
.sortGroup(new StringFieldExtractor<Tuple3<Integer, Long, String>>(2), Order.DESCENDING)
.reduceGroup(new Tuple3SortedGroupReduceWithCombine());
reduceDs.writeAsCsv(resultPath);
reduceDs.print();
env.execute();
// return expected result
if (super.mode == ExecutionMode.COLLECTION) {
expected = null;
} else {
expected = "1,Hi\n" +
"5,Hello world-Hello\n" +
"15,Luke Skywalker-I am fine.-Hello world, how are you?\n" +
"34,Comment#4-Comment#3-Comment#2-Comment#1\n" +
"65,Comment#9-Comment#8-Comment#7-Comment#6-Comment#5\n" +
"111,Comment#15-Comment#14-Comment#13-Comment#12-Comment#11-Comment#10\n";
}
}
public static class FiveToTwoTupleExtractor implements KeySelector<Tuple5<Integer, Long, Integer, String, Long>, Tuple2<Long, Integer>> {
private static final long serialVersionUID = 1L;
@Override
public Tuple2<Long, Integer> getKey(Tuple5<Integer, Long, Integer, String, Long> in) {
return new Tuple2<Long, Integer>(in.f4, in.f2);
}
}
@Test
public void testTupleKeySelectorSortCombineOnTuple() throws Exception {
/*
* check correctness of sorted groupReduce with Tuple2 keyselector sorting
*/
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setDegreeOfParallelism(1);
DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds = CollectionDataSets.get5TupleDataSet(env);
DataSet<Tuple5<Integer, Long, Integer, String, Long>> reduceDs = ds
.groupBy(new IntFieldExtractor<Tuple5<Integer, Long, Integer, String, Long>>(0))
.sortGroup(new FiveToTwoTupleExtractor(), Order.DESCENDING)
.reduceGroup(new Tuple5SortedGroupReduce());
reduceDs.writeAsCsv(resultPath);
env.execute();
// return expected result
expected = "1,1,0,Hallo,1\n" +
"2,5,0,Hallo Welt-Hallo Welt wie,1\n" +
"3,15,0,BCD-ABC-Hallo Welt wie gehts?,2\n" +
"4,34,0,FGH-CDE-EFG-DEF,1\n" +
"5,65,0,IJK-HIJ-KLM-JKL-GHI,1\n";
}
public static class GroupReducer5 implements GroupReduceFunction<CollectionDataSets.PojoContainingTupleAndWritable, String> {
@Override
public void reduce(
......@@ -905,6 +1092,33 @@ public class GroupReduceITCase extends MultipleProgramsTestBase {
out.collect(new Tuple5<Integer, Long, Integer, String, Long>(i, l, 0, "P-)", l2));
}
}
public static class Tuple5SortedGroupReduce implements GroupReduceFunction<Tuple5<Integer, Long, Integer, String, Long>, Tuple5<Integer, Long, Integer, String, Long>> {
private static final long serialVersionUID = 1L;
@Override
public void reduce(
Iterable<Tuple5<Integer, Long, Integer, String, Long>> values,
Collector<Tuple5<Integer, Long, Integer, String, Long>> out)
{
int i = 0;
long l = 0l;
long l2 = 0l;
StringBuilder concat = new StringBuilder();
for ( Tuple5<Integer, Long, Integer, String, Long> t : values ) {
i = t.f0;
l += t.f1;
concat.append(t.f3).append("-");
l2 = t.f4;
}
if (concat.length() > 0) {
concat.setLength(concat.length() - 1);
}
out.collect(new Tuple5<Integer, Long, Integer, String, Long>(i, l, 0, concat.toString(), l2));
}
}
public static class CustomTypeGroupReduce implements GroupReduceFunction<CustomType, CustomType> {
private static final long serialVersionUID = 1L;
......@@ -931,6 +1145,33 @@ public class GroupReduceITCase extends MultipleProgramsTestBase {
}
}
public static class CustomTypeSortedGroupReduce implements GroupReduceFunction<CustomType, CustomType> {
private static final long serialVersionUID = 1L;
@Override
public void reduce(Iterable<CustomType> values, Collector<CustomType> out) {
final Iterator<CustomType> iter = values.iterator();
CustomType o = new CustomType();
CustomType c = iter.next();
StringBuilder concat = new StringBuilder(c.myString);
o.myInt = c.myInt;
o.myLong = c.myLong;
while (iter.hasNext()) {
CustomType next = iter.next();
concat.append("-").append(next.myString);
o.myLong += next.myLong;
}
o.myString = concat.toString();
out.collect(o);
}
}
public static class InputReturningTuple3GroupReduce implements GroupReduceFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> {
private static final long serialVersionUID = 1L;
......@@ -1052,6 +1293,44 @@ public class GroupReduceITCase extends MultipleProgramsTestBase {
}
}
@RichGroupReduceFunction.Combinable
public static class Tuple3SortedGroupReduceWithCombine extends RichGroupReduceFunction<Tuple3<Integer, Long, String>, Tuple2<Integer, String>> {
private static final long serialVersionUID = 1L;
@Override
public void combine(Iterable<Tuple3<Integer, Long, String>> values, Collector<Tuple3<Integer, Long, String>> out) {
int sum = 0;
long key = 0;
StringBuilder concat = new StringBuilder();
for (Tuple3<Integer, Long, String> next : values) {
sum += next.f0;
key = next.f1;
concat.append(next.f2).append("-");
}
if (concat.length() > 0) {
concat.setLength(concat.length() - 1);
}
out.collect(new Tuple3<Integer, Long, String>(sum, key, concat.toString()));
}
@Override
public void reduce(Iterable<Tuple3<Integer, Long, String>> values, Collector<Tuple2<Integer, String>> out) {
int i = 0;
String s = "";
for ( Tuple3<Integer, Long, String> t : values ) {
i += t.f0;
s = t.f2;
}
out.collect(new Tuple2<Integer, String>(i, s));
}
}
@RichGroupReduceFunction.Combinable
public static class Tuple3AllGroupReduceWithCombine extends RichGroupReduceFunction<Tuple3<Integer, Long, String>, Tuple2<Integer, String>> {
......
......@@ -595,6 +595,127 @@ class GroupReduceITCase(mode: ExecutionMode) extends MultipleProgramsTestBase(mo
expected = "1---(10,100)-\n" + "2---(30,600)-(30,400)-(30,200)-(20,201)-(20,200)-\n"
}
@Test
def testTupleKeySelectorGroupSort: Unit = {
/*
* check correctness of sorted groupReduce on tuples with keyselector sorting
*/
val env = ExecutionEnvironment.getExecutionEnvironment
env.setDegreeOfParallelism(1)
val ds = CollectionDataSets.get3TupleDataSet(env)
val reduceDs = ds.groupBy(_._2).sortGroup(_._3, Order.DESCENDING).reduceGroup {
in =>
in.reduce((l, r) => (l._1 + r._1, l._2, l._3 + "-" + r._3))
}
reduceDs.writeAsCsv(resultPath)
env.execute()
expected = "1,1,Hi\n" +
"5,2,Hello world-Hello\n" +
"15,3,Luke Skywalker-I am fine.-Hello world, how are you?\n" +
"34,4,Comment#4-Comment#3-Comment#2-Comment#1\n" +
"65,5,Comment#9-Comment#8-Comment#7-Comment#6-Comment#5\n" +
"111,6,Comment#15-Comment#14-Comment#13-Comment#12-Comment#11-Comment#10\n"
}
@Test
def testPojoKeySelectorGroupSort: Unit = {
/*
* check correctness of sorted groupReduce on custom type with keyselector sorting
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = CollectionDataSets.getCustomTypeDataSet(env)
val reduceDs = ds.groupBy(_.myInt).sortGroup(_.myString, Order.DESCENDING).reduceGroup {
in =>
val iter = in.toIterator
val o = new CustomType
val c = iter.next()
val concat: StringBuilder = new StringBuilder(c.myString)
o.myInt = c.myInt
o.myLong = c.myLong
while (iter.hasNext) {
val next = iter.next()
o.myLong += next.myLong
concat.append("-").append(next.myString)
}
o.myString = concat.toString()
o
}
reduceDs.writeAsText(resultPath)
env.execute()
expected = "1,0,Hi\n" +
"2,3,Hello world-Hello\n" +
"3,12,Luke Skywalker-I am fine.-Hello world, how are you?\n" +
"4,30,Comment#4-Comment#3-Comment#2-Comment#1\n" +
"5,60,Comment#9-Comment#8-Comment#7-Comment#6-Comment#5\n" +
"6,105,Comment#15-Comment#14-Comment#13-Comment#12-Comment#11-Comment#10\n"
}
@Test
def testTupleKeySelectorSortWithCombine: Unit = {
/*
* check correctness of sorted groupReduce with combine on tuples with keyselector sorting
*/
val env = ExecutionEnvironment.getExecutionEnvironment
env.setDegreeOfParallelism(1)
val ds = CollectionDataSets.get3TupleDataSet(env)
val reduceDs = ds.groupBy(_._2).sortGroup(_._3, Order.DESCENDING)
.reduceGroup(new Tuple3SortedGroupReduceWithCombine)
reduceDs.writeAsCsv(resultPath)
env.execute()
if (mode == ExecutionMode.COLLECTION) {
expected = null
} else {
expected = "1,Hi\n" +
"5,Hello world-Hello\n" +
"15,Luke Skywalker-I am fine.-Hello world, how are you?\n" +
"34,Comment#4-Comment#3-Comment#2-Comment#1\n" +
"65,Comment#9-Comment#8-Comment#7-Comment#6-Comment#5\n" +
"111,Comment#15-Comment#14-Comment#13-Comment#12-Comment#11-Comment#10\n"
}
}
@Test
def testTupleKeySelectorSortCombineOnTuple: Unit = {
/*
* check correctness of sorted groupReduce with Tuple2 keyselector sorting
*/
val env = ExecutionEnvironment.getExecutionEnvironment
env.setDegreeOfParallelism(1)
val ds = CollectionDataSets.get5TupleDataSet(env)
val reduceDs = ds.groupBy(_._1).sortGroup(t => (t._5, t._3), Order.DESCENDING).reduceGroup{
in =>
val iter = in.toIterator
val concat: StringBuilder = new StringBuilder
var sum: Long = 0
var key = 0
var s: Long = 0
while (iter.hasNext) {
val next = iter.next()
sum += next._2
key = next._1
s = next._5
concat.append(next._4).append("-")
}
if (concat.length > 0) {
concat.setLength(concat.length - 1)
}
(key, sum, 0, concat.toString(), s)
// in.reduce((l, r) => (l._1, l._2 + r._2, 0, l._4 + "-" + r._4, l._5))
}
reduceDs.writeAsCsv(resultPath)
env.execute()
expected = "1,1,0,Hallo,1\n" +
"2,5,0,Hallo Welt-Hallo Welt wie,1\n" +
"3,15,0,BCD-ABC-Hallo Welt wie gehts?,2\n" +
"4,34,0,FGH-CDE-EFG-DEF,1\n" +
"5,65,0,IJK-HIJ-KLM-JKL-GHI,1\n";
}
@Test
def testGroupingWithPojoContainingMultiplePojos: Unit = {
/*
......@@ -761,4 +882,38 @@ class NestedTupleReducer extends GroupReduceFunction[((Int, Int), String), Strin
}
}
@RichGroupReduceFunction.Combinable
class Tuple3SortedGroupReduceWithCombine
extends RichGroupReduceFunction[(Int, Long, String), (Int, String)] {
override def combine(
values: Iterable[(Int, Long, String)],
out: Collector[(Int, Long, String)]): Unit = {
val concat: StringBuilder = new StringBuilder
var sum = 0
var key: Long = 0
for (t <- values.asScala) {
sum += t._1
key = t._2
concat.append(t._3).append("-")
}
if (concat.length > 0) {
concat.setLength(concat.length - 1)
}
out.collect((sum, key, concat.toString()))
}
override def reduce(
values: Iterable[(Int, Long, String)],
out: Collector[(Int, String)]): Unit = {
var i = 0
var s = ""
for (t <- values.asScala) {
i += t._1
s = t._3
}
out.collect((i, s))
}
}