Commit 3eac6f23, authored by Stephan Ewen

[FLINK-1110] Implement collection-based execution for mapPartition.

Make the groupReduce code compliant with pre-Java-8 versions; fix the Java 8 tests after the type information classes were moved.

Fix various warnings.
Parent fd3f5c27
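For readers unfamiliar with the operator: a MapPartitionFunction sees all records of a partition as a single Iterable and emits results through a Collector, which is why collection-based execution can simply hand it the whole input list. A minimal illustrative function (not part of this commit; the class name is made up) could look like this:

import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.util.Collector;

// Illustrative only: upper-cases every String in the partition.
public class UpperCaseMapPartition implements MapPartitionFunction<String, String> {
	@Override
	public void mapPartition(Iterable<String> values, Collector<String> out) throws Exception {
		for (String value : values) {
			out.collect(value.toUpperCase());
		}
	}
}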
......@@ -18,14 +18,16 @@
package org.apache.flink.api.common.operators.base;
import java.util.List;
import org.apache.flink.api.common.functions.GenericCollectorMap;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
/**
* The CollectorMap is the old version of the Map operator. It is effectively a "flatMap", where the
* UDF is called "map".
......@@ -46,4 +48,11 @@ public class CollectorMapOperatorBase<IN, OUT, FT extends GenericCollectorMap<IN
public CollectorMapOperatorBase(Class<? extends FT> udf, UnaryOperatorInformation<IN, OUT> operatorInfo, String name) {
super(new UserCodeClassWrapper<FT>(udf), operatorInfo, name);
}
// --------------------------------------------------------------------------------------------
@Override
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx) {
throw new UnsupportedOperationException();
}
}
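The javadoc above describes CollectorMap as the old Map operator that effectively behaves like a flatMap whose UDF is named map; since it is legacy, its collection-based execution is simply left unsupported (the new executeOnCollections throws UnsupportedOperationException). Purely as a point of comparison, a sketch of the flatMap contract it resembles, using the non-deprecated FlatMapFunction interface:

import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.util.Collector;

// Comparison sketch: one input record may emit zero or more output records through the Collector.
public class SplitWords implements FlatMapFunction<String, String> {
	@Override
	public void flatMap(String line, Collector<String> out) throws Exception {
		for (String word : line.split(" ")) {
			out.collect(word);
		}
	}
}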
......@@ -37,6 +37,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeComparator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
......@@ -144,6 +145,7 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN,
int[] inputColumns = getKeyColumns(0);
boolean[] inputOrderings = new boolean[inputColumns.length];
@SuppressWarnings("unchecked")
final TypeComparator<IN> inputComparator =
((CompositeType<IN>) inputType).createComparator(inputColumns, inputOrderings);
......@@ -154,10 +156,10 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN,
ArrayList<OUT> result = new ArrayList<OUT>(inputData.size());
ListCollector<OUT> collector = new ListCollector<OUT>(result);
- inputData.sort( new Comparator<IN>() {
+ Collections.sort(inputData, new Comparator<IN>() {
@Override
public int compare(IN o1, IN o2) {
- return - inputComparator.compare(o1, o2);
+ return inputComparator.compare(o2, o1);
}
});
ListKeyGroupedIterator<IN> keyedIterator =
......
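The change above swaps the Java 8 List.sort call for Collections.sort so the code also compiles on older Java versions; the sorted input is then consumed group by group through ListKeyGroupedIterator. A standalone sketch of that sort-then-walk-runs idea (not Flink's actual iterator, and assuming the comparator returns 0 exactly for records with equal keys):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public class SortThenGroupSketch {
	public static void main(String[] args) {
		List<String> data = new ArrayList<String>(Arrays.asList("b1", "a1", "b2", "a2"));
		// Order by key (here: the first character), mirroring the Collections.sort call above.
		Comparator<String> byKey = new Comparator<String>() {
			@Override
			public int compare(String o1, String o2) {
				return Character.compare(o1.charAt(0), o2.charAt(0));
			}
		};
		Collections.sort(data, byKey);
		// Walk the sorted list and emit one group per run of key-equal records.
		int start = 0;
		for (int i = 1; i <= data.size(); i++) {
			if (i == data.size() || byKey.compare(data.get(start), data.get(i)) != 0) {
				System.out.println("group: " + data.subList(start, i));
				start = i;
			}
		}
	}
}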
......@@ -18,7 +18,13 @@
package org.apache.flink.api.common.operators.base;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
......@@ -44,4 +50,23 @@ public class MapPartitionOperatorBase<IN, OUT, FT extends MapPartitionFunction<I
public MapPartitionOperatorBase(Class<? extends FT> udf, UnaryOperatorInformation<IN, OUT> operatorInfo, String name) {
super(new UserCodeClassWrapper<FT>(udf), operatorInfo, name);
}
// --------------------------------------------------------------------------------------------
@Override
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx) throws Exception {
MapPartitionFunction<IN, OUT> function = this.userFunction.getUserCodeObject();
FunctionUtils.setFunctionRuntimeContext(function, ctx);
FunctionUtils.openFunction(function, this.parameters);
ArrayList<OUT> result = new ArrayList<OUT>(inputData.size() / 4);
ListCollector<OUT> resultCollector = new ListCollector<OUT>(result);
function.mapPartition(inputData, resultCollector);
result.trimToSize();
FunctionUtils.closeFunction(function);
return result;
}
}
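executeOnCollections gathers the function's output through a ListCollector. Judging from its use here, ListCollector appends every collected record to the backing list; a minimal collector of that kind (an assumption about the behavior, not the actual ListCollector source) would be:

import java.util.List;
import org.apache.flink.util.Collector;

// Sketch of a list-backed collector: collect() appends to the wrapped list, close() has nothing to release.
public class SimpleListCollector<T> implements Collector<T> {
	private final List<T> target;

	public SimpleListCollector(List<T> target) {
		this.target = target;
	}

	@Override
	public void collect(T record) {
		target.add(record);
	}

	@Override
	public void close() {
		// no resources to release
	}
}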
......@@ -34,6 +34,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@SuppressWarnings("serial")
public class FlatMapOperatorCollectionExecutionTest implements Serializable {
@Test
......@@ -61,6 +62,7 @@ public class FlatMapOperatorCollectionExecutionTest implements Serializable {
}
}
public class IdRichFlatMap<IN> extends RichFlatMapFunction<IN, IN> {
private boolean isOpened = false;
......
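The body of IdRichFlatMap is elided in this view. Going only by its name and the isOpened flag, a plausible shape (purely hypothetical, not the committed code) is an identity flatMap that verifies open() ran before any element is processed:

import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;

public class IdentityRichFlatMapSketch<IN> extends RichFlatMapFunction<IN, IN> {
	private boolean isOpened = false;

	@Override
	public void open(Configuration parameters) throws Exception {
		isOpened = true;
	}

	@Override
	public void flatMap(IN value, Collector<IN> out) throws Exception {
		if (!isOpened) {
			throw new IllegalStateException("open() was not called before flatMap()");
		}
		out.collect(value);
	}
}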
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.operators.base;
import static org.junit.Assert.*;
import static java.util.Arrays.asList;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.flink.util.Collector;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.junit.Test;
@SuppressWarnings("serial")
public class PartitionMapOperatorTest implements java.io.Serializable {
@Test
public void testMapPartitionWithRuntimeContext() {
try {
final String taskName = "Test Task";
final AtomicBoolean opened = new AtomicBoolean();
final AtomicBoolean closed = new AtomicBoolean();
final MapPartitionFunction<String, Integer> parser = new RichMapPartitionFunction<String, Integer>() {
@Override
public void open(Configuration parameters) throws Exception {
opened.set(true);
RuntimeContext ctx = getRuntimeContext();
assertEquals(0, ctx.getIndexOfThisSubtask());
assertEquals(1, ctx.getNumberOfParallelSubtasks());
assertEquals(taskName, ctx.getTaskName());
}
@Override
public void mapPartition(Iterable<String> values, Collector<Integer> out) {
for (String s : values) {
out.collect(Integer.parseInt(s));
}
}
@Override
public void close() throws Exception {
closed.set(true);
}
};
MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>> op =
new MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String,Integer>>(
parser, new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), taskName);
List<String> input = new ArrayList<String>(asList("1", "2", "3", "4", "5", "6"));
List<Integer> result = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0));
assertEquals(asList(1, 2, 3, 4, 5, 6), result);
assertTrue(opened.get());
assertTrue(closed.get());
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
}
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.operators;
import static org.junit.Assert.*;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.CollectionEnvironment;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOuputFormat;
import org.apache.flink.configuration.Configuration;
import org.junit.Test;
public class CollectionExecutionAccumulatorsTest {
private static final String ACCUMULATOR_NAME = "TEST ACC";
@Test
public void testAccumulator() {
try {
final int NUM_ELEMENTS = 100;
ExecutionEnvironment env = new CollectionEnvironment();
env.generateSequence(1, NUM_ELEMENTS)
.map(new CountingMapper())
.output(new DiscardingOuputFormat<Long>());
JobExecutionResult result = env.execute();
assertTrue(result.getNetRuntime() >= 0);
assertEquals(NUM_ELEMENTS, result.getAccumulatorResult(ACCUMULATOR_NAME));
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@SuppressWarnings("serial")
public static class CountingMapper extends RichMapFunction<Long, Long> {
private IntCounter accumulator;
@Override
public void open(Configuration parameters) {
accumulator = getRuntimeContext().getIntCounter(ACCUMULATOR_NAME);
}
@Override
public Long map(Long value) {
accumulator.add(1);
return value;
}
}
}
......@@ -19,9 +19,7 @@
package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
......
......@@ -21,8 +21,6 @@ import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
- import junit.framework.Assert;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.java.DataSet;
......@@ -30,6 +28,7 @@ import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.UnsortedGrouping;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+ import org.junit.Assert;
import org.junit.Test;
public class MaxByOperatorTest {
......
......@@ -21,7 +21,7 @@ import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
- import junit.framework.Assert;
+ import org.junit.Assert;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
......
......@@ -69,7 +69,7 @@ public class ReduceITCase extends JavaProgramTestBase {
BasicTypeInfo.LONG_TYPE_INFO,
BasicTypeInfo.INT_TYPE_INFO,
BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.LONG_TYPE_INFO
BasicTypeInfo<T>.LONG_TYPE_INFO
);
return env.fromCollection(data, type);
......