提交 aabb2688 编写于 作者: K Kostas Kloudas 提交者: Fabian Hueske

[FLINK-3254] [dataSet] Adding functionality to support the CombineFunction contract.

This closes #1568
上级 f7d19115
......@@ -18,9 +18,10 @@
package org.apache.flink.api.java.operators;
import org.apache.flink.api.common.functions.CombineFunction;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.operators.Keys;
import org.apache.flink.api.java.operators.translation.CombineToGroupCombineWrapper;
import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.Ordering;
......@@ -30,6 +31,7 @@ import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.SemanticPropUtil;
import org.apache.flink.api.common.operators.Keys.SelectorFunctionKeys;
import org.apache.flink.api.common.operators.Keys.ExpressionKeys;
import org.apache.flink.api.java.operators.translation.PlanUnwrappingReduceGroupOperator;
import org.apache.flink.api.java.operators.translation.PlanUnwrappingSortedReduceGroupOperator;
import org.apache.flink.api.java.tuple.Tuple2;
......@@ -52,7 +54,7 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
private static final Logger LOG = LoggerFactory.getLogger(GroupReduceOperator.class);
private final GroupReduceFunction<IN, OUT> function;
private GroupReduceFunction<IN, OUT> function;
private final Grouping<IN> grouper;
......@@ -68,12 +70,12 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
*/
public GroupReduceOperator(DataSet<IN> input, TypeInformation<OUT> resultType, GroupReduceFunction<IN, OUT> function, String defaultName) {
super(input, resultType);
this.function = function;
this.grouper = null;
this.defaultName = defaultName;
checkCombinability();
this.combinable = checkCombinability();
}
/**
......@@ -84,18 +86,18 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
*/
public GroupReduceOperator(Grouping<IN> input, TypeInformation<OUT> resultType, GroupReduceFunction<IN, OUT> function, String defaultName) {
super(input != null ? input.getInputDataSet() : null, resultType);
this.function = function;
this.grouper = input;
this.defaultName = defaultName;
checkCombinability();
this.combinable = checkCombinability();
UdfOperatorUtils.analyzeSingleInputUdf(this, GroupReduceFunction.class, defaultName, function, grouper.keys);
}
private void checkCombinability() {
if (function instanceof GroupCombineFunction) {
private boolean checkCombinability() {
if (function instanceof GroupCombineFunction || function instanceof CombineFunction) {
// check if the generic types of GroupCombineFunction and GroupReduceFunction match, i.e.,
// GroupCombineFunction<IN, IN> and GroupReduceFunction<IN, OUT>.
......@@ -110,7 +112,9 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
if (((ParameterizedType) genInterface).getRawType().equals(GroupReduceFunction.class)) {
reduceTypes = ((ParameterizedType) genInterface).getActualTypeArguments();
// get parameters of GroupCombineFunction
} else if (((ParameterizedType) genInterface).getRawType().equals(GroupCombineFunction.class)) {
} else if ((((ParameterizedType) genInterface).getRawType().equals(GroupCombineFunction.class)) ||
(((ParameterizedType) genInterface).getRawType().equals(CombineFunction.class))) {
combineTypes = ((ParameterizedType) genInterface).getActualTypeArguments();
}
}
......@@ -120,24 +124,25 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
combineTypes != null && combineTypes.length == 2) {
if (reduceTypes[0].equals(combineTypes[0]) && reduceTypes[0].equals(combineTypes[1])) {
this.combinable = true;
return true;
} else {
LOG.warn("GroupCombineFunction cannot be used as combiner for GroupReduceFunction. " +
"Generic types are incompatible.");
this.combinable = false;
return false;
}
}
else if (reduceTypes == null || reduceTypes.length != 2) {
LOG.warn("Cannot check generic types of GroupReduceFunction. " +
"Enabling combiner but combine function might fail at runtime.");
this.combinable = true;
return true;
}
else {
LOG.warn("Cannot check generic types of GroupCombineFunction. " +
"Enabling combiner but combine function might fail at runtime.");
this.combinable = true;
return true;
}
}
return false;
}
......@@ -156,13 +161,18 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
}
public GroupReduceOperator<IN, OUT> setCombinable(boolean combinable) {
// sanity check that the function is a subclass of the combine interface
if (combinable && !(function instanceof GroupCombineFunction)) {
throw new IllegalArgumentException("The function does not implement the combine interface.");
if(combinable) {
// sanity check that the function is a subclass of the combine interface
if (!checkCombinability()) {
throw new IllegalArgumentException("Either the function does not implement a combine interface, " +
"or the types of the combine() and reduce() methods are not compatible.");
}
this.combinable = true;
}
else {
this.combinable = false;
}
this.combinable = combinable;
return this;
}
......@@ -191,10 +201,16 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
// --------------------------------------------------------------------------------------------
@Override
@SuppressWarnings("unchecked")
protected GroupReduceOperatorBase<?, OUT, ?> translateToDataFlow(Operator<IN> input) {
String name = getName() != null ? getName() : "GroupReduce at " + defaultName;
// wrap CombineFunction in GroupCombineFunction if combinable
this.function = (combinable && function instanceof CombineFunction<?,?>) ?
new CombineToGroupCombineWrapper((CombineFunction<?,?>) function) :
function;
// distinguish between grouped reduce and non-grouped reduce
if (grouper == null) {
// non grouped reduce
......@@ -236,7 +252,7 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
return po;
}
}
else if (grouper.getKeys() instanceof Keys.ExpressionKeys) {
else if (grouper.getKeys() instanceof ExpressionKeys) {
int[] logicalKeyPositions = grouper.getKeys().computeLogicalKeyPositions();
UnaryOperatorInformation<IN, OUT> operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType());
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.operators.translation;
import com.google.common.base.Preconditions;
import org.apache.flink.api.common.functions.CombineFunction;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.util.Collector;
/**
* A wrapper the wraps a function that implements both {@link CombineFunction} and {@link GroupReduceFunction} interfaces
* and makes it look like a function that implements {@link GroupCombineFunction} and {@link GroupReduceFunction} to the runtime.
*/
public class CombineToGroupCombineWrapper<IN, OUT, F extends CombineFunction<IN, IN> & GroupReduceFunction<IN, OUT>>
implements GroupCombineFunction<IN, IN>, GroupReduceFunction<IN, OUT> {
private final F wrappedFunction;
public CombineToGroupCombineWrapper(F wrappedFunction) {
this.wrappedFunction = Preconditions.checkNotNull(wrappedFunction);
}
@Override
public void combine(Iterable<IN> values, Collector<IN> out) throws Exception {
IN outValue = wrappedFunction.combine(values);
out.collect(outValue);
}
@Override
public void reduce(Iterable<IN> values, Collector<OUT> out) throws Exception {
wrappedFunction.reduce(values, out);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.javaApiOperators;
import org.apache.flink.api.common.functions.CombineFunction;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.UnsortedGrouping;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.apache.flink.util.Collector;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.List;
@SuppressWarnings("serial")
@RunWith(Parameterized.class)
public class ReduceWithCombinerITCase extends MultipleProgramsTestBase {
public ReduceWithCombinerITCase(TestExecutionMode mode) {
super(TestExecutionMode.CLUSTER);
}
@Test
public void testReduceOnNonKeyedDataset() throws Exception {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
// creates the input data and distributes them evenly among the available downstream tasks
DataSet<Tuple2<Integer, Boolean>> input = createNonKeyedInput(env);
List<Tuple2<Integer, Boolean>> actual = input.reduceGroup(new NonKeyedCombReducer()).collect();
String expected = "10,true\n";
compareResultAsTuples(actual, expected);
}
@Test
public void testForkingReduceOnNonKeyedDataset() throws Exception {
// set up the execution environment
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
// creates the input data and distributes them evenly among the available downstream tasks
DataSet<Tuple2<Integer, Boolean>> input = createNonKeyedInput(env);
DataSet<Tuple2<Integer, Boolean>> r1 = input.reduceGroup(new NonKeyedCombReducer());
DataSet<Tuple2<Integer, Boolean>> r2 = input.reduceGroup(new NonKeyedGroupCombReducer());
List<Tuple2<Integer, Boolean>> actual = r1.union(r2).collect();
String expected = "10,true\n10,true\n";
compareResultAsTuples(actual, expected);
}
@Test
public void testReduceOnKeyedDataset() throws Exception {
// set up the execution environment
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
// creates the input data and distributes them evenly among the available downstream tasks
DataSet<Tuple3<String, Integer, Boolean>> input = createKeyedInput(env);
List<Tuple3<String, Integer, Boolean>> actual = input.groupBy(0).reduceGroup(new KeyedCombReducer()).collect();
String expected = "k1,6,true\nk2,4,true\n";
compareResultAsTuples(actual, expected);
}
@Test
public void testReduceOnKeyedDatasetWithSelector() throws Exception {
// set up the execution environment
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
// creates the input data and distributes them evenly among the available downstream tasks
DataSet<Tuple3<String, Integer, Boolean>> input = createKeyedInput(env);
List<Tuple3<String, Integer, Boolean>> actual = input
.groupBy(new KeySelectorX())
.reduceGroup(new KeyedCombReducer())
.collect();
String expected = "k1,6,true\nk2,4,true\n";
compareResultAsTuples(actual, expected);
}
@Test
public void testForkingReduceOnKeyedDataset() throws Exception {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
// creates the input data and distributes them evenly among the available downstream tasks
DataSet<Tuple3<String, Integer, Boolean>> input = createKeyedInput(env);
UnsortedGrouping<Tuple3<String, Integer, Boolean>> counts = input.groupBy(0);
DataSet<Tuple3<String, Integer, Boolean>> r1 = counts.reduceGroup(new KeyedCombReducer());
DataSet<Tuple3<String, Integer, Boolean>> r2 = counts.reduceGroup(new KeyedGroupCombReducer());
List<Tuple3<String, Integer, Boolean>> actual = r1.union(r2).collect();
String expected = "k1,6,true\n" +
"k2,4,true\n" +
"k1,6,true\n" +
"k2,4,true\n";
compareResultAsTuples(actual, expected);
}
@Test
public void testForkingReduceOnKeyedDatasetWithSelection() throws Exception {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
// creates the input data and distributes them evenly among the available downstream tasks
DataSet<Tuple3<String, Integer, Boolean>> input = createKeyedInput(env);
UnsortedGrouping<Tuple3<String, Integer, Boolean>> counts = input.groupBy(new KeySelectorX());
DataSet<Tuple3<String, Integer, Boolean>> r1 = counts.reduceGroup(new KeyedCombReducer());
DataSet<Tuple3<String, Integer, Boolean>> r2 = counts.reduceGroup(new KeyedGroupCombReducer());
List<Tuple3<String, Integer, Boolean>> actual = r1.union(r2).collect();
String expected = "k1,6,true\n" +
"k2,4,true\n" +
"k1,6,true\n" +
"k2,4,true\n";
compareResultAsTuples(actual, expected);
}
private DataSet<Tuple2<Integer, Boolean>> createNonKeyedInput(ExecutionEnvironment env) {
return env.fromCollection(Arrays.asList(
new Tuple2<>(1, false),
new Tuple2<>(1, false),
new Tuple2<>(1, false),
new Tuple2<>(1, false),
new Tuple2<>(1, false),
new Tuple2<>(1, false),
new Tuple2<>(1, false),
new Tuple2<>(1, false),
new Tuple2<>(1, false),
new Tuple2<>(1, false))
).rebalance();
}
private static class NonKeyedCombReducer implements CombineFunction<Tuple2<Integer, Boolean>, Tuple2<Integer, Boolean>>,
GroupReduceFunction<Tuple2<Integer, Boolean>,Tuple2<Integer, Boolean>> {
@Override
public Tuple2<Integer, Boolean> combine(Iterable<Tuple2<Integer, Boolean>> values) throws Exception {
int sum = 0;
boolean flag = true;
for(Tuple2<Integer, Boolean> tuple : values) {
sum += tuple.f0;
flag &= !tuple.f1;
}
return new Tuple2<>(sum, flag);
}
@Override
public void reduce(Iterable<Tuple2<Integer, Boolean>> values, Collector<Tuple2<Integer, Boolean>> out) throws Exception {
int sum = 0;
boolean flag = true;
for(Tuple2<Integer, Boolean> tuple : values) {
sum += tuple.f0;
flag &= tuple.f1;
}
out.collect(new Tuple2<>(sum, flag));
}
}
private static class NonKeyedGroupCombReducer implements GroupCombineFunction<Tuple2<Integer, Boolean>, Tuple2<Integer, Boolean>>,
GroupReduceFunction<Tuple2<Integer, Boolean>,Tuple2<Integer, Boolean>> {
@Override
public void reduce(Iterable<Tuple2<Integer, Boolean>> values, Collector<Tuple2<Integer, Boolean>> out) throws Exception {
int sum = 0;
boolean flag = true;
for(Tuple2<Integer, Boolean> tuple : values) {
sum += tuple.f0;
flag &= tuple.f1;
}
out.collect(new Tuple2<>(sum, flag));
}
@Override
public void combine(Iterable<Tuple2<Integer, Boolean>> values, Collector<Tuple2<Integer, Boolean>> out) throws Exception {
int sum = 0;
boolean flag = true;
for(Tuple2<Integer, Boolean> tuple : values) {
sum += tuple.f0;
flag &= !tuple.f1;
}
out.collect(new Tuple2<>(sum, flag));
}
}
private DataSet<Tuple3<String, Integer, Boolean>> createKeyedInput(ExecutionEnvironment env) {
return env.fromCollection(Arrays.asList(
new Tuple3<>("k1", 1, false),
new Tuple3<>("k1", 1, false),
new Tuple3<>("k1", 1, false),
new Tuple3<>("k2", 1, false),
new Tuple3<>("k1", 1, false),
new Tuple3<>("k1", 1, false),
new Tuple3<>("k2", 1, false),
new Tuple3<>("k2", 1, false),
new Tuple3<>("k1", 1, false),
new Tuple3<>("k2", 1, false))
).rebalance();
}
public static class KeySelectorX implements KeySelector<Tuple3<String, Integer, Boolean>, String> {
private static final long serialVersionUID = 1L;
@Override
public String getKey(Tuple3<String, Integer, Boolean> in) {
return in.f0;
}
}
private class KeyedCombReducer implements CombineFunction<Tuple3<String, Integer, Boolean>, Tuple3<String, Integer, Boolean>>,
GroupReduceFunction<Tuple3<String, Integer, Boolean>, Tuple3<String, Integer, Boolean>> {
@Override
public Tuple3<String, Integer, Boolean> combine(Iterable<Tuple3<String, Integer, Boolean>> values) throws Exception {
String key = null;
int sum = 0;
boolean flag = true;
for(Tuple3<String, Integer, Boolean> tuple : values) {
key = (key == null) ? tuple.f0 : key;
sum += tuple.f1;
flag &= !tuple.f2;
}
return new Tuple3<>(key, sum, flag);
}
@Override
public void reduce(Iterable<Tuple3<String, Integer, Boolean>> values, Collector<Tuple3<String, Integer, Boolean>> out) throws Exception {
String key = null;
int sum = 0;
boolean flag = true;
for(Tuple3<String, Integer, Boolean> tuple : values) {
key = (key == null) ? tuple.f0 : key;
sum += tuple.f1;
flag &= tuple.f2;
}
out.collect(new Tuple3<>(key, sum, flag));
}
}
private class KeyedGroupCombReducer implements GroupCombineFunction<Tuple3<String, Integer, Boolean>, Tuple3<String, Integer, Boolean>>,
GroupReduceFunction<Tuple3<String, Integer, Boolean>, Tuple3<String, Integer, Boolean>> {
@Override
public void combine(Iterable<Tuple3<String, Integer, Boolean>> values, Collector<Tuple3<String, Integer, Boolean>> out) throws Exception {
String key = null;
int sum = 0;
boolean flag = true;
for(Tuple3<String, Integer, Boolean> tuple : values) {
key = (key == null) ? tuple.f0 : key;
sum += tuple.f1;
flag &= !tuple.f2;
}
out.collect(new Tuple3<>(key, sum, flag));
}
@Override
public void reduce(Iterable<Tuple3<String, Integer, Boolean>> values, Collector<Tuple3<String, Integer, Boolean>> out) throws Exception {
String key = null;
int sum = 0;
boolean flag = true;
for(Tuple3<String, Integer, Boolean> tuple : values) {
key = (key == null) ? tuple.f0 : key;
sum += tuple.f1;
flag &= tuple.f2;
}
out.collect(new Tuple3<>(key, sum, flag));
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册