提交 ac69cb3e 编写于 作者: S Stephan Ewen

[FLINK-1110] By default, collection-based execution behaves mutable-object safe.

上级 3fd3110c
......@@ -43,7 +43,7 @@ public class NoOpBinaryUdfOp<OUT> extends DualInputOperator<OUT, OUT, OUT, NoOpF
}
@Override
protected List<OUT> executeOnCollections(List<OUT> inputData1, List<OUT> inputData2, RuntimeContext runtimeContext) {
protected List<OUT> executeOnCollections(List<OUT> inputData1, List<OUT> inputData2, RuntimeContext runtimeContext, boolean mutables) {
throw new UnsupportedOperationException();
}
}
......
......@@ -54,7 +54,7 @@ public class NoOpUnaryUdfOp<OUT> extends SingleInputOperator<OUT, OUT, NoOpFunct
}
@Override
protected List<OUT> executeOnCollections(List<OUT> inputData, RuntimeContext runtimeContext) {
protected List<OUT> executeOnCollections(List<OUT> inputData, RuntimeContext runtimeContext, boolean mutables) {
return inputData;
}
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.functions.util;
import java.util.Iterator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.util.TraversableOnceException;
public class CopyingIterator<E> implements Iterator<E>, Iterable<E> {
private final Iterator<E> source;
private final TypeSerializer<E> serializer;
private boolean available = true;
public CopyingIterator(Iterator<E> source, TypeSerializer<E> serializer) {
this.source = source;
this.serializer = serializer;
}
@Override
public Iterator<E> iterator() {
if (available) {
available = false;
return this;
} else {
throw new TraversableOnceException();
}
}
@Override
public boolean hasNext() {
return source.hasNext();
}
@Override
public E next() {
E next = source.next();
return serializer.copy(next);
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
}
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.functions.util;
import java.util.List;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.util.Collector;
public class CopyingListCollector<T> implements Collector<T> {
private final List<T> list;
private final TypeSerializer<T> serializer;
public CopyingListCollector(List<T> list, TypeSerializer<T> serializer) {
this.list = list;
this.serializer = serializer;
}
@Override
public void collect(T record) {
list.add(serializer.copy(record));
}
@Override
public void close() {}
}
......@@ -52,13 +52,22 @@ import org.apache.flink.util.Visitor;
*/
public class CollectionExecutor {
private static final boolean DEFAULT_MUTABLE_OBJECT_SAFE_MODE = true;
private final Map<Operator<?>, List<?>> intermediateResults;
private final Map<String, Accumulator<?, ?>> accumulators;
private final boolean mutableObjectSafeMode;
// --------------------------------------------------------------------------------------------
public CollectionExecutor() {
this(DEFAULT_MUTABLE_OBJECT_SAFE_MODE);
}
public CollectionExecutor(boolean mutableObjectSafeMode) {
this.mutableObjectSafeMode = mutableObjectSafeMode;
this.intermediateResults = new HashMap<Operator<?>, List<?>>();
this.accumulators = new HashMap<String, Accumulator<?,?>>();
}
......@@ -172,7 +181,7 @@ public class CollectionExecutor {
ctx = null;
}
List<OUT> result = typedOp.executeOnCollections(inputData, ctx);
List<OUT> result = typedOp.executeOnCollections(inputData, ctx, mutableObjectSafeMode);
if (ctx != null) {
AccumulatorHelper.mergeInto(this.accumulators, ctx.getAllAccumulators());
......@@ -214,7 +223,7 @@ public class CollectionExecutor {
ctx = null;
}
List<OUT> result = typedOp.executeOnCollections(inputData1, inputData2, ctx);
List<OUT> result = typedOp.executeOnCollections(inputData1, inputData2, ctx, mutableObjectSafeMode);
if (ctx != null) {
AccumulatorHelper.mergeInto(this.accumulators, ctx.getAllAccumulators());
......
......@@ -286,5 +286,5 @@ public abstract class DualInputOperator<IN1, IN2, OUT, FT extends Function> exte
// --------------------------------------------------------------------------------------------
protected abstract List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext runtimeContext) throws Exception;
protected abstract List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext runtimeContext, boolean mutableObjectSafeMode) throws Exception;
}
......@@ -209,5 +209,5 @@ public abstract class SingleInputOperator<IN, OUT, FT extends Function> extends
// --------------------------------------------------------------------------------------------
protected abstract List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext runtimeContext) throws Exception;
protected abstract List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext runtimeContext, boolean mutableObjectSafeMode) throws Exception;
}
......@@ -48,7 +48,7 @@ public class Union<T> extends DualInputOperator<T, T, T, AbstractRichFunction> {
}
@Override
protected List<T> executeOnCollections(List<T> inputData1, List<T> inputData2, RuntimeContext runtimeContext) {
protected List<T> executeOnCollections(List<T> inputData1, List<T> inputData2, RuntimeContext runtimeContext, boolean mutableObjectSafeMode) {
ArrayList<T> result = new ArrayList<T>(inputData1.size() + inputData2.size());
result.addAll(inputData1);
result.addAll(inputData2);
......
......@@ -303,7 +303,7 @@ public class BulkIterationBase<T> extends SingleInputOperator<T, T, AbstractRich
}
@Override
protected List<T> executeOnCollections(List<T> inputData, RuntimeContext runtimeContext) {
protected List<T> executeOnCollections(List<T> inputData, RuntimeContext runtimeContext, boolean mutableObjectSafeMode) {
throw new UnsupportedOperationException();
}
}
......@@ -21,6 +21,7 @@ package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.CopyingListCollector;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
......@@ -35,6 +36,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.GenericPairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.util.Collector;
import java.io.IOException;
......@@ -177,7 +179,7 @@ public class CoGroupOperatorBase<IN1, IN2, OUT, FT extends CoGroupFunction<IN1,
// ------------------------------------------------------------------------
@Override
protected List<OUT> executeOnCollections(List<IN1> input1, List<IN2> input2, RuntimeContext ctx) throws Exception {
protected List<OUT> executeOnCollections(List<IN1> input1, List<IN2> input2, RuntimeContext ctx, boolean mutableObjectSafe) throws Exception {
// --------------------------------------------------------------------
// Setup
// --------------------------------------------------------------------
......@@ -193,11 +195,15 @@ public class CoGroupOperatorBase<IN1, IN2, OUT, FT extends CoGroupFunction<IN1,
Arrays.fill(inputSortDirections1, true);
Arrays.fill(inputSortDirections2, true);
final TypeSerializer<IN1> inputSerializer1 = inputType1.createSerializer();
final TypeSerializer<IN2> inputSerializer2 = inputType2.createSerializer();
final TypeComparator<IN1> inputComparator1 = getTypeComparator(inputType1, inputKeys1, inputSortDirections1);
final TypeComparator<IN2> inputComparator2 = getTypeComparator(inputType2, inputKeys2, inputSortDirections2);
CoGroupSortListIterator<IN1, IN2> coGroupIterator =
new CoGroupSortListIterator<IN1, IN2>(input1, inputComparator1, input2, inputComparator2);
new CoGroupSortListIterator<IN1, IN2>(input1, inputComparator1, inputSerializer1,
input2, inputComparator2, inputSerializer2, mutableObjectSafe);
// --------------------------------------------------------------------
// Run UDF
......@@ -208,7 +214,9 @@ public class CoGroupOperatorBase<IN1, IN2, OUT, FT extends CoGroupFunction<IN1,
FunctionUtils.openFunction(function, parameters);
List<OUT> result = new ArrayList<OUT>();
Collector<OUT> resultCollector = new ListCollector<OUT>(result);
Collector<OUT> resultCollector = mutableObjectSafe ?
new CopyingListCollector<OUT>(result, getOperatorInfo().getOutputType().createSerializer()) :
new ListCollector<OUT>(result);
while (coGroupIterator.next()) {
function.coGroup(coGroupIterator.getValues1(), coGroupIterator.getValues2(), resultCollector);
......@@ -247,13 +255,14 @@ public class CoGroupOperatorBase<IN1, IN2, OUT, FT extends CoGroupFunction<IN1,
private Iterable<IN2> secondReturn;
private CoGroupSortListIterator(
List<IN1> input1, final TypeComparator<IN1> inputComparator1,
List<IN2> input2, final TypeComparator<IN2> inputComparator2) {
List<IN1> input1, final TypeComparator<IN1> inputComparator1, TypeSerializer<IN1> serializer1,
List<IN2> input2, final TypeComparator<IN2> inputComparator2, TypeSerializer<IN2> serializer2,
boolean copyElements)
{
this.pairComparator = new GenericPairComparator<IN1, IN2>(inputComparator1, inputComparator2);
this.iterator1 = new ListKeyGroupedIterator<IN1>(input1, inputComparator1);
this.iterator2 = new ListKeyGroupedIterator<IN2>(input2, inputComparator2);
this.iterator1 = new ListKeyGroupedIterator<IN1>(input1, serializer1, inputComparator1, copyElements);
this.iterator2 = new ListKeyGroupedIterator<IN2>(input2, serializer2, inputComparator2, copyElements);
// ----------------------------------------------------------------
// Sort
......
......@@ -52,7 +52,7 @@ public class CollectorMapOperatorBase<IN, OUT, FT extends GenericCollectorMap<IN
// --------------------------------------------------------------------------------------------
@Override
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx) {
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx, boolean mutableObjectSafeMode) {
throw new UnsupportedOperationException();
}
}
......@@ -29,6 +29,7 @@ import org.apache.flink.api.common.operators.DualInputOperator;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeutils.TypeSerializer;
/**
* @see org.apache.flink.api.common.functions.CrossFunction
......@@ -50,21 +51,37 @@ public class CrossOperatorBase<IN1, IN2, OUT, FT extends CrossFunction<IN1, IN2,
// --------------------------------------------------------------------------------------------
@Override
protected List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext ctx) throws Exception {
protected List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
CrossFunction<IN1, IN2, OUT> function = this.userFunction.getUserCodeObject();
FunctionUtils.setFunctionRuntimeContext(function, ctx);
FunctionUtils.openFunction(function, this.parameters);
ArrayList<OUT> result = new ArrayList<OUT>(inputData1.size() * inputData2.size());
for (IN1 element1 : inputData1) {
for (IN2 element2 : inputData2) {
result.add(function.cross(element1, element2));
if (mutableObjectSafeMode) {
TypeSerializer<IN1> inSerializer1 = getOperatorInfo().getFirstInputType().createSerializer();
TypeSerializer<IN2> inSerializer2 = getOperatorInfo().getSecondInputType().createSerializer();
TypeSerializer<OUT> outSerializer = getOperatorInfo().getOutputType().createSerializer();
for (IN1 element1 : inputData1) {
for (IN2 element2 : inputData2) {
IN1 copy1 = inSerializer1.copy(element1);
IN2 copy2 = inSerializer2.copy(element2);
OUT o = function.cross(copy1, copy2);
result.add(outSerializer.copy(o));
}
}
}
else {
for (IN1 element1 : inputData1) {
for (IN2 element2 : inputData2) {
result.add(function.cross(element1, element2));
}
}
}
FunctionUtils.closeFunction(function);
return result;
}
......
......@@ -16,7 +16,6 @@
* limitations under the License.
*/
package org.apache.flink.api.common.operators.base;
import java.util.Collections;
......@@ -333,7 +332,7 @@ public class DeltaIterationBase<ST, WT> extends DualInputOperator<ST, WT, ST, Ab
}
@Override
protected List<ST> executeOnCollections(List<ST> inputData1, List<WT> inputData2, RuntimeContext runtimeContext) throws Exception {
protected List<ST> executeOnCollections(List<ST> inputData1, List<WT> inputData2, RuntimeContext runtimeContext, boolean mutableObjectSafeMode) {
throw new UnsupportedOperationException();
}
}
......@@ -16,7 +16,6 @@
* limitations under the License.
*/
package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.io.FileOutputFormat;
......@@ -72,8 +71,7 @@ public class FileDataSinkBase<IN> extends GenericDataSinkBase<IN> {
*
* @return The path to which the output shall be written.
*/
public String getFilePath()
{
public String getFilePath() {
return this.filePath;
}
......@@ -82,5 +80,4 @@ public class FileDataSinkBase<IN> extends GenericDataSinkBase<IN> {
public String toString() {
return this.filePath;
}
}
......@@ -50,7 +50,7 @@ public class FilterOperatorBase<T, FT extends FlatMapFunction<T, T>> extends Sin
}
@Override
protected List<T> executeOnCollections(List<T> inputData, RuntimeContext ctx) throws Exception {
protected List<T> executeOnCollections(List<T> inputData, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
FlatMapFunction<T, T> function = this.userFunction.getUserCodeObject();
FunctionUtils.openFunction(function, this.parameters);
......
......@@ -20,6 +20,7 @@ package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.CopyingListCollector;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.api.common.operators.SingleInputOperator;
......@@ -27,6 +28,7 @@ import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import java.util.ArrayList;
import java.util.List;
......@@ -51,17 +53,29 @@ public class FlatMapOperatorBase<IN, OUT, FT extends FlatMapFunction<IN, OUT>> e
// ------------------------------------------------------------------------
@Override
protected List<OUT> executeOnCollections(List<IN> input, RuntimeContext ctx) throws Exception {
protected List<OUT> executeOnCollections(List<IN> input, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
FlatMapFunction<IN, OUT> function = userFunction.getUserCodeObject();
FunctionUtils.setFunctionRuntimeContext(function, ctx);
FunctionUtils.openFunction(function, parameters);
ArrayList<OUT> result = new ArrayList<OUT>(input.size());
ListCollector<OUT> resultCollector = new ListCollector<OUT>(result);
for (IN element : input) {
function.flatMap(element, resultCollector);
if (mutableObjectSafeMode) {
TypeSerializer<IN> inSerializer = getOperatorInfo().getInputType().createSerializer();
TypeSerializer<OUT> outSerializer = getOperatorInfo().getOutputType().createSerializer();
CopyingListCollector<OUT> resultCollector = new CopyingListCollector<OUT>(result, outSerializer);
for (IN element : input) {
IN inCopy = inSerializer.copy(element);
function.flatMap(inCopy, resultCollector);
}
} else {
ListCollector<OUT> resultCollector = new ListCollector<OUT>(result);
for (IN element : input) {
function.flatMap(element, resultCollector);
}
}
FunctionUtils.closeFunction(function);
......
......@@ -18,11 +18,11 @@
package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.FlatCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.CopyingListCollector;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.api.common.operators.Ordering;
......@@ -35,6 +35,7 @@ import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.CompositeType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import java.util.ArrayList;
import java.util.Collections;
......@@ -131,20 +132,21 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN,
// --------------------------------------------------------------------------------------------
@Override
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx)
throws Exception {
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
GroupReduceFunction<IN, OUT> function = this.userFunction.getUserCodeObject();
UnaryOperatorInformation<IN, OUT> operatorInfo = getOperatorInfo();
TypeInformation<IN> inputType = operatorInfo.getInputType();
if (!(inputType instanceof CompositeType)) {
throw new InvalidProgramException("Input type of groupReduce operation must be" +
" composite type.");
throw new InvalidProgramException("Input type of groupReduce operation must be a composite type.");
}
int[] inputColumns = getKeyColumns(0);
boolean[] inputOrderings = new boolean[inputColumns.length];
final TypeSerializer<IN> inputSerializer = inputType.createSerializer();
@SuppressWarnings("unchecked")
final TypeComparator<IN> inputComparator =
((CompositeType<IN>) inputType).createComparator(inputColumns, inputOrderings);
......@@ -152,26 +154,34 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN,
FunctionUtils.setFunctionRuntimeContext(function, ctx);
FunctionUtils.openFunction(function, this.parameters);
ArrayList<OUT> result = new ArrayList<OUT>(inputData.size());
ListCollector<OUT> collector = new ListCollector<OUT>(result);
Collections.sort(inputData, new Comparator<IN>() {
@Override
public int compare(IN o1, IN o2) {
return inputComparator.compare(o2, o1);
}
});
ListKeyGroupedIterator<IN> keyedIterator =
new ListKeyGroupedIterator<IN>(inputData, inputComparator);
while (keyedIterator.nextKey()) {
function.reduce(keyedIterator.getValues(), collector);
ListKeyGroupedIterator<IN> keyedIterator = new ListKeyGroupedIterator<IN>(
inputData, inputSerializer, inputComparator, mutableObjectSafeMode);
ArrayList<OUT> result = new ArrayList<OUT>();
if (mutableObjectSafeMode) {
TypeSerializer<OUT> outSerializer = getOperatorInfo().getOutputType().createSerializer();
CopyingListCollector<OUT> collector = new CopyingListCollector<OUT>(result, outSerializer);
while (keyedIterator.nextKey()) {
function.reduce(keyedIterator.getValues(), collector);
}
}
else {
ListCollector<OUT> collector = new ListCollector<OUT>(result);
while (keyedIterator.nextKey()) {
function.reduce(keyedIterator.getValues(), collector);
}
}
FunctionUtils.closeFunction(function);
return result;
}
}
......@@ -16,11 +16,11 @@
* limitations under the License.
*/
package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.CopyingListCollector;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
......@@ -34,6 +34,8 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.GenericPairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.Arrays;
......@@ -60,7 +62,7 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
@SuppressWarnings("unchecked")
@Override
protected List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext runtimeContext) throws Exception {
protected List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext runtimeContext, boolean mutableObjectSafe) throws Exception {
FlatJoinFunction<IN1, IN2, OUT> function = userFunction.getUserCodeObject();
FunctionUtils.setFunctionRuntimeContext(function, runtimeContext);
......@@ -68,13 +70,18 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
TypeInformation<IN1> leftInformation = getOperatorInfo().getFirstInputType();
TypeInformation<IN2> rightInformation = getOperatorInfo().getSecondInputType();
TypeInformation<OUT> outInformation = getOperatorInfo().getOutputType();
TypeSerializer<IN1> leftSerializer = mutableObjectSafe ? leftInformation.createSerializer() : null;
TypeSerializer<IN2> rightSerializer = mutableObjectSafe ? rightInformation.createSerializer() : null;
TypeComparator<IN1> leftComparator;
TypeComparator<IN2> rightComparator;
if(leftInformation instanceof AtomicType){
if (leftInformation instanceof AtomicType){
leftComparator = ((AtomicType<IN1>) leftInformation).createComparator(true);
}else if(leftInformation instanceof CompositeType){
}
else if(leftInformation instanceof CompositeType){
int[] keyPositions = getKeyColumns(0);
boolean[] orders = new boolean[keyPositions.length];
Arrays.fill(orders, true);
......@@ -102,12 +109,13 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
rightComparator);
List<OUT> result = new ArrayList<OUT>();
ListCollector<OUT> collector = new ListCollector<OUT>(result);
Collector<OUT> collector = mutableObjectSafe ? new CopyingListCollector<OUT>(result, outInformation.createSerializer())
: new ListCollector<OUT>(result);
Map<Integer, List<IN2>> probeTable = new HashMap<Integer, List<IN2>>();
//Build probe table
for(IN2 element: inputData2){
//Build hash table
for (IN2 element: inputData2){
List<IN2> list = probeTable.get(rightComparator.hash(element));
if(list == null){
list = new ArrayList<IN2>();
......@@ -118,15 +126,18 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
}
//Probing
for(IN1 left: inputData1){
for (IN1 left: inputData1) {
List<IN2> matchingHashes = probeTable.get(leftComparator.hash(left));
pairComparator.setReference(left);
if(matchingHashes != null){
for(IN2 right: matchingHashes){
if(pairComparator.equalToReference(right)){
function.join(left, right, collector);
if (matchingHashes != null) {
pairComparator.setReference(left);
for (IN2 right : matchingHashes){
if (pairComparator.equalToReference(right)) {
if (mutableObjectSafe) {
function.join(leftSerializer.copy(left), rightSerializer.copy(right), collector);
} else {
function.join(left, right, collector);
}
}
}
}
......
......@@ -16,7 +16,6 @@
* limitations under the License.
*/
package org.apache.flink.api.common.operators.base;
import java.util.ArrayList;
......@@ -30,7 +29,7 @@ import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeutils.TypeSerializer;
/**
*
......@@ -55,15 +54,27 @@ public class MapOperatorBase<IN, OUT, FT extends MapFunction<IN, OUT>> extends S
// --------------------------------------------------------------------------------------------
@Override
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx) throws Exception {
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
MapFunction<IN, OUT> function = this.userFunction.getUserCodeObject();
FunctionUtils.setFunctionRuntimeContext(function, ctx);
FunctionUtils.openFunction(function, this.parameters);
ArrayList<OUT> result = new ArrayList<OUT>(inputData.size());
for (IN element : inputData) {
result.add(function.map(element));
if (mutableObjectSafeMode) {
TypeSerializer<IN> inSerializer = getOperatorInfo().getInputType().createSerializer();
TypeSerializer<OUT> outSerializer = getOperatorInfo().getOutputType().createSerializer();
for (IN element : inputData) {
IN inCopy = inSerializer.copy(element);
OUT out = function.map(inCopy);
result.add(outSerializer.copy(out));
}
} else {
for (IN element : inputData) {
result.add(function.map(element));
}
}
FunctionUtils.closeFunction(function);
......
......@@ -23,6 +23,8 @@ import java.util.List;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.CopyingIterator;
import org.apache.flink.api.common.functions.util.CopyingListCollector;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.api.common.operators.SingleInputOperator;
......@@ -30,6 +32,7 @@ import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeutils.TypeSerializer;
/**
*
......@@ -54,18 +57,28 @@ public class MapPartitionOperatorBase<IN, OUT, FT extends MapPartitionFunction<I
// --------------------------------------------------------------------------------------------
@Override
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx) throws Exception {
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
MapPartitionFunction<IN, OUT> function = this.userFunction.getUserCodeObject();
FunctionUtils.setFunctionRuntimeContext(function, ctx);
FunctionUtils.openFunction(function, this.parameters);
ArrayList<OUT> result = new ArrayList<OUT>(inputData.size() / 4);
ListCollector<OUT> resultCollector = new ListCollector<OUT>(result);
function.mapPartition(inputData, resultCollector);
if (mutableObjectSafeMode) {
TypeSerializer<IN> inSerializer = getOperatorInfo().getInputType().createSerializer();
TypeSerializer<OUT> outSerializer = getOperatorInfo().getOutputType().createSerializer();
CopyingIterator<IN> source = new CopyingIterator<IN>(inputData.iterator(), inSerializer);
CopyingListCollector<OUT> resultCollector = new CopyingListCollector<OUT>(result, outSerializer);
function.mapPartition(source, resultCollector);
} else {
ListCollector<OUT> resultCollector = new ListCollector<OUT>(result);
function.mapPartition(inputData, resultCollector);
}
result.trimToSize();
FunctionUtils.closeFunction(function);
return result;
}
......
......@@ -18,6 +18,9 @@
package org.apache.flink.api.common.operators.base;
import java.util.List;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.NoOpFunction;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
......@@ -51,4 +54,8 @@ public class PartitionOperatorBase<IN> extends SingleInputOperator<IN, IN, NoOpF
RANGE;
}
@Override
protected List<IN> executeOnCollections(List<IN> inputData, RuntimeContext runtimeContext, boolean mutableObjectSafeMode) {
return inputData;
}
}
......@@ -16,7 +16,6 @@
* limitations under the License.
*/
package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.InvalidProgramException;
......@@ -32,13 +31,14 @@ import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.CompositeType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Base data flow operator for Reduce user-defined functions. Accepts reduce functions
* and key positions. The key positions are expected in the flattened common data model.
......@@ -123,20 +123,23 @@ public class ReduceOperatorBase<T, FT extends ReduceFunction<T>> extends SingleI
super(new UserCodeClassWrapper<FT>(udf), operatorInfo, name);
}
// --------------------------------------------------------------------------------------------
// --------------------------------------------------------------------------------------------
@SuppressWarnings("unchecked")
@Override
protected List<T> executeOnCollections(List<T> inputData, RuntimeContext ctx)
throws Exception {
protected List<T> executeOnCollections(List<T> inputData, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
// make sure we can handle empty inputs
if (inputData.isEmpty()) {
return Collections.emptyList();
}
ReduceFunction<T> function = this.userFunction.getUserCodeObject();
UnaryOperatorInformation<T, T> operatorInfo = getOperatorInfo();
TypeInformation<T> inputType = operatorInfo.getInputType();
if (!(inputType instanceof CompositeType)) {
throw new InvalidProgramException("Input type of groupReduce operation must be" +
" composite type.");
throw new InvalidProgramException("Input type of groupReduce operation must be" + " composite type.");
}
FunctionUtils.setFunctionRuntimeContext(function, ctx);
......@@ -161,22 +164,30 @@ public class ReduceOperatorBase<T, FT extends ReduceFunction<T>> extends SingleI
aggregateMap.put(wrapper, result);
}
List<T> result = new ArrayList<T>(aggregateMap.values().size());
result.addAll(aggregateMap.values());
FunctionUtils.closeFunction(function);
return result;
} else {
return new ArrayList<T>(aggregateMap.values());
}
else {
T aggregate = inputData.get(0);
for (int i = 1; i < inputData.size(); i++) {
aggregate = function.reduce(aggregate, inputData.get(i));
if (mutableObjectSafeMode) {
TypeSerializer<T> serializer = getOperatorInfo().getInputType().createSerializer();
aggregate = serializer.copy(aggregate);
for (int i = 1; i < inputData.size(); i++) {
T next = function.reduce(aggregate, serializer.copy(inputData.get(i)));
aggregate = serializer.copy(next);
}
}
else {
for (int i = 1; i < inputData.size(); i++) {
aggregate = function.reduce(aggregate, inputData.get(i));
}
}
List<T> result = new ArrayList<T>(1);
result.add(aggregate);
FunctionUtils.setFunctionRuntimeContext(function, ctx);
return result;
return Collections.singletonList(aggregate);
}
}
}
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
......@@ -19,6 +19,7 @@
package org.apache.flink.api.common.operators.util;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import java.io.IOException;
import java.util.Iterator;
......@@ -32,14 +33,16 @@ public final class ListKeyGroupedIterator<E> {
private final List<E> input;
private final TypeSerializer<E> serializer; // != null if the elements should be copied
private final TypeComparator<E> comparator;
private ValuesIterator valuesIterator;
private int currentPosition = 0;
private E lookahead = null;
private E lookahead;
private boolean done;
/**
......@@ -48,13 +51,13 @@ public final class ListKeyGroupedIterator<E> {
* @param input The list with the input elements.
* @param comparator The comparator for the data type iterated over.
*/
public ListKeyGroupedIterator(List<E> input, TypeComparator<E> comparator)
{
public ListKeyGroupedIterator(List<E> input, TypeSerializer<E> serializer, TypeComparator<E> comparator, boolean copy) {
if (input == null || comparator == null) {
throw new NullPointerException();
}
this.input = input;
this.serializer = copy ? serializer : null;
this.comparator = comparator;
this.done = input.isEmpty() ? true : false;
......@@ -109,7 +112,7 @@ public final class ListKeyGroupedIterator<E> {
E first = input.get(currentPosition++);
if (first != null) {
this.comparator.setReference(first);
this.valuesIterator = new ValuesIterator(first);
this.valuesIterator = new ValuesIterator(first, serializer);
return true;
}
else {
......@@ -155,9 +158,12 @@ public final class ListKeyGroupedIterator<E> {
public final class ValuesIterator implements Iterator<E>, Iterable<E> {
private E next;
private final TypeSerializer<E> serializer;
private ValuesIterator(E first) {
private ValuesIterator(E first, TypeSerializer<E> serializer) {
this.next = first;
this.serializer = serializer;
}
@Override
......@@ -170,7 +176,7 @@ public final class ListKeyGroupedIterator<E> {
if (this.next != null) {
E current = this.next;
this.next = ListKeyGroupedIterator.this.advanceToNext();
return current;
return serializer != null ? serializer.copy(current) : current;
} else {
throw new NoSuchElementException();
}
......
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
......@@ -30,8 +30,8 @@ import org.junit.Assert;
import org.junit.Test;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@SuppressWarnings("serial")
......@@ -41,25 +41,34 @@ public class FlatMapOperatorCollectionTest implements Serializable {
public void testExecuteOnCollection() {
try {
IdRichFlatMap<String> udf = new IdRichFlatMap<String>();
testExecuteOnCollection(udf, Arrays.asList("f", "l", "i", "n", "k"));
testExecuteOnCollection(udf, Arrays.asList("f", "l", "i", "n", "k"), true);
Assert.assertTrue(udf.isClosed);
testExecuteOnCollection(new IdRichFlatMap<String>(), new ArrayList<String>());
} catch (Throwable t) {
Assert.fail(t.getMessage());
udf = new IdRichFlatMap<String>();
testExecuteOnCollection(udf, Arrays.asList("f", "l", "i", "n", "k"), false);
Assert.assertTrue(udf.isClosed);
udf = new IdRichFlatMap<String>();
testExecuteOnCollection(udf, Collections.<String>emptyList(), true);
Assert.assertTrue(udf.isClosed);
udf = new IdRichFlatMap<String>();
testExecuteOnCollection(udf, Collections.<String>emptyList(), false);
Assert.assertTrue(udf.isClosed);
}
catch (Exception e) {
e.printStackTrace();
Assert.fail(e.getMessage());
}
}
private void testExecuteOnCollection(FlatMapFunction<String, String> udf, List<String> input) throws Exception {
private void testExecuteOnCollection(FlatMapFunction<String, String> udf, List<String> input, boolean mutableSafe) throws Exception {
// run on collections
final List<String> result = getTestFlatMapOperator(udf)
.executeOnCollections(input, new RuntimeUDFContext("Test UDF", 4, 0));
.executeOnCollections(input, new RuntimeUDFContext("Test UDF", 4, 0), mutableSafe);
Assert.assertEquals(input.size(), result.size());
for (int i = 0; i < input.size(); i++) {
Assert.assertEquals(input.get(i), result.get(i));
}
Assert.assertEquals(input, result);
}
......
......@@ -60,10 +60,13 @@ public class JoinOperatorBaseTest implements Serializable {
List<Integer> expected = new ArrayList<Integer>(Arrays.asList(3, 3, 6 ,6));
try {
List<Integer> result = base.executeOnCollections(inputData1, inputData2, null);
List<Integer> resultSafe = base.executeOnCollections(inputData1, inputData2, null, true);
List<Integer> resultRegular = base.executeOnCollections(inputData1, inputData2, null, false);
assertEquals(expected, result);
} catch (Exception e) {
assertEquals(expected, resultSafe);
assertEquals(expected, resultRegular);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
......@@ -107,11 +110,13 @@ public class JoinOperatorBaseTest implements Serializable {
try {
List<Integer> result = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext(taskName,
1, 0));
List<Integer> resultSafe = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext(taskName, 1, 0), true);
List<Integer> resultRegular = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext(taskName, 1, 0), false);
assertEquals(expected, result);
} catch (Exception e) {
assertEquals(expected, resultSafe);
assertEquals(expected, resultRegular);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
......
......@@ -52,9 +52,11 @@ public class MapOperatorTest implements java.io.Serializable {
parser, new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), "TestMapper");
List<String> input = new ArrayList<String>(asList("1", "2", "3", "4", "5", "6"));
List<Integer> result = op.executeOnCollections(input, null);
List<Integer> resultMutableSafe = op.executeOnCollections(input, null, true);
List<Integer> resultRegular = op.executeOnCollections(input, null, false);
assertEquals(asList(1, 2, 3, 4, 5, 6), result);
assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);
}
catch (Exception e) {
e.printStackTrace();
......@@ -95,9 +97,11 @@ public class MapOperatorTest implements java.io.Serializable {
parser, new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), taskName);
List<String> input = new ArrayList<String>(asList("1", "2", "3", "4", "5", "6"));
List<Integer> result = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0));
List<Integer> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0), true);
List<Integer> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0), false);
assertEquals(asList(1, 2, 3, 4, 5, 6), result);
assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);
assertTrue(opened.get());
assertTrue(closed.get());
......
......@@ -74,9 +74,12 @@ public class PartitionMapOperatorTest implements java.io.Serializable {
parser, new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), taskName);
List<String> input = new ArrayList<String>(asList("1", "2", "3", "4", "5", "6"));
List<Integer> result = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0));
assertEquals(asList(1, 2, 3, 4, 5, 6), result);
List<Integer> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0), true);
List<Integer> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0), false);
assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);
assertTrue(opened.get());
assertTrue(closed.get());
......
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
......
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
......@@ -20,7 +20,6 @@ package org.apache.flink.api.common.operators;
//CHECKSTYLE.OFF: AvoidStarImport - Needed for TupleGenerator
import static org.junit.Assert.*;
//CHECKSTYLE.ON: AvoidStarImport
import java.util.ArrayList;
import java.util.List;
......@@ -102,11 +101,14 @@ public class CollectionExecutionIterationTest implements java.io.Serializable {
try {
ExecutionEnvironment env = new CollectionEnvironment();
@SuppressWarnings("unchecked")
DataSet<Tuple2<Integer, Integer>> solInput = env.fromElements(
new Tuple2<Integer, Integer>(1, 0),
new Tuple2<Integer, Integer>(2, 0),
new Tuple2<Integer, Integer>(3, 0),
new Tuple2<Integer, Integer>(4, 0));
@SuppressWarnings("unchecked")
DataSet<Tuple1<Integer>> workInput = env.fromElements(
new Tuple1<Integer>(1),
new Tuple1<Integer>(2),
......
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
......
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
......@@ -38,6 +38,7 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
@SuppressWarnings("serial")
public class CoGroupOperatorCollectionTest implements Serializable {
@Test
......@@ -68,12 +69,17 @@ public class CoGroupOperatorCollectionTest implements Serializable {
final RuntimeContext ctx = new RuntimeUDFContext("Test UDF", 4, 0);
{
SumCoGroup udf = new SumCoGroup();
List<Tuple2<String, Integer>> result = getCoGroupOperator(udf)
.executeOnCollections(input1, input2, ctx);
Assert.assertTrue(udf.isClosed);
SumCoGroup udf1 = new SumCoGroup();
SumCoGroup udf2 = new SumCoGroup();
List<Tuple2<String, Integer>> resultSafe = getCoGroupOperator(udf1)
.executeOnCollections(input1, input2, ctx, true);
List<Tuple2<String, Integer>> resultRegular = getCoGroupOperator(udf2)
.executeOnCollections(input1, input2, ctx, false);
Assert.assertTrue(udf1.isClosed);
Assert.assertTrue(udf2.isClosed);
Set<Tuple2<String, Integer>> expected = new HashSet<Tuple2<String, Integer>>(
Arrays.asList(new Tuple2Builder<String, Integer>()
.add("foo", 8)
......@@ -84,14 +90,21 @@ public class CoGroupOperatorCollectionTest implements Serializable {
)
);
Assert.assertEquals(expected, new HashSet(result));
Assert.assertEquals(expected, new HashSet<Tuple2<String, Integer>>(resultSafe));
Assert.assertEquals(expected, new HashSet<Tuple2<String, Integer>>(resultRegular));
}
{
List<Tuple2<String, Integer>> result = getCoGroupOperator(new SumCoGroup())
.executeOnCollections(Collections.EMPTY_LIST, Collections.EMPTY_LIST, ctx);
Assert.assertEquals(0, result.size());
List<Tuple2<String, Integer>> resultSafe = getCoGroupOperator(new SumCoGroup())
.executeOnCollections(Collections.<Tuple2<String, Integer>>emptyList(),
Collections.<Tuple2<String, Integer>>emptyList(), ctx, true);
List<Tuple2<String, Integer>> resultRegular = getCoGroupOperator(new SumCoGroup())
.executeOnCollections(Collections.<Tuple2<String, Integer>>emptyList(),
Collections.<Tuple2<String, Integer>>emptyList(), ctx, false);
Assert.assertEquals(0, resultSafe.size());
Assert.assertEquals(0, resultRegular.size());
}
} catch (Throwable t) {
t.printStackTrace();
......
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
......@@ -79,15 +79,21 @@ public class GroupReduceOperatorTest implements java.io.Serializable {
Integer>("foo", 3), new Tuple2<String, Integer>("bar", 2), new Tuple2<String,
Integer>("bar", 4)));
List<Tuple2<String, Integer>> result = op.executeOnCollections(input, null);
Set<Tuple2<String, Integer>> resultSet = new HashSet<Tuple2<String, Integer>>(result);
List<Tuple2<String, Integer>> resultMutableSafe = op.executeOnCollections(input, null, true);
List<Tuple2<String, Integer>> resultRegular = op.executeOnCollections(input, null, false);
Set<Tuple2<String, Integer>> resultSetMutableSafe = new HashSet<Tuple2<String, Integer>>(resultMutableSafe);
Set<Tuple2<String, Integer>> resultSetRegular = new HashSet<Tuple2<String, Integer>>(resultRegular);
Set<Tuple2<String, Integer>> expectedResult = new HashSet<Tuple2<String,
Integer>>(asList(new Tuple2<String, Integer>("foo", 4), new Tuple2<String,
Integer>("bar", 6)));
assertEquals(expectedResult, resultSet);
} catch (Exception e) {
assertEquals(expectedResult, resultSetMutableSafe);
assertEquals(expectedResult, resultSetRegular);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
......@@ -149,15 +155,19 @@ public class GroupReduceOperatorTest implements java.io.Serializable {
Integer>("foo", 3), new Tuple2<String, Integer>("bar", 2), new Tuple2<String,
Integer>("bar", 4)));
List<Tuple2<String, Integer>> result = op.executeOnCollections(input,
new RuntimeUDFContext(taskName, 1, 0));
Set<Tuple2<String, Integer>> resultSet = new HashSet<Tuple2<String, Integer>>(result);
List<Tuple2<String, Integer>> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0), true);
List<Tuple2<String, Integer>> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0), false);
Set<Tuple2<String, Integer>> resultSetMutableSafe = new HashSet<Tuple2<String, Integer>>(resultMutableSafe);
Set<Tuple2<String, Integer>> resultSetRegular = new HashSet<Tuple2<String, Integer>>(resultRegular);
Set<Tuple2<String, Integer>> expectedResult = new HashSet<Tuple2<String,
Integer>>(asList(new Tuple2<String, Integer>("foo", 4), new Tuple2<String,
Integer>("bar", 6)));
assertEquals(expectedResult, resultSet);
assertEquals(expectedResult, resultSetMutableSafe);
assertEquals(expectedResult, resultSetRegular);
assertTrue(opened.get());
assertTrue(closed.get());
......
......@@ -16,12 +16,12 @@
* limitations under the License.
*/
package org.apache.flink.api.java.operators.base;
package org.apache.flink.api.common.operators.base;
import static org.junit.Assert.*;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.base.JoinOperatorBase;
import org.apache.flink.api.java.tuple.Tuple2;
......@@ -31,21 +31,23 @@ import org.apache.flink.util.Collector;
import org.junit.Test;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@SuppressWarnings({ "unchecked", "serial" })
public class JoinOperatorBaseTest implements Serializable {
@Test
public void testTupleBaseJoiner(){
final FlatJoinFunction<Tuple3<String, Double, Integer>, Tuple2<Integer,
String>, Tuple2<Double, String>> joiner = new FlatJoinFunction() {
final FlatJoinFunction<Tuple3<String, Double, Integer>, Tuple2<Integer, String>, Tuple2<Double, String>> joiner =
new FlatJoinFunction<Tuple3<String, Double, Integer>, Tuple2<Integer, String>, Tuple2<Double, String>>()
{
@Override
public void join(Object first, Object second, Collector out) throws Exception {
public void join(Tuple3<String, Double, Integer> first, Tuple2<Integer, String> second, Collector<Tuple2<Double, String>> out) {
Tuple3<String, Double, Integer> fst = (Tuple3<String, Double, Integer>)first;
Tuple2<Integer, String> snd = (Tuple2<Integer, String>)second;
......@@ -99,18 +101,15 @@ public class JoinOperatorBaseTest implements Serializable {
));
try {
Method executeOnCollections = base.getClass().getDeclaredMethod("executeOnCollections", List.class,
List.class, RuntimeContext.class);
executeOnCollections.setAccessible(true);
List<Tuple2<Double, String>> resultSafe = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext("op", 1, 0), true);
List<Tuple2<Double, String>> resultRegular = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext("op", 1, 0), false);
Object result = executeOnCollections.invoke(base, inputData1, inputData2, null);
assertEquals(expected, new HashSet<Tuple2<Double, String>>((List<Tuple2<Double, String>>)result));
} catch (Exception e) {
assertEquals(expected, new HashSet<Tuple2<Double, String>>(resultSafe));
assertEquals(expected, new HashSet<Tuple2<Double, String>>(resultRegular));
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
}
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
......@@ -68,15 +68,20 @@ public class ReduceOperatorTest implements java.io.Serializable {
Integer>("foo", 3), new Tuple2<String, Integer>("bar", 2), new Tuple2<String,
Integer>("bar", 4)));
List<Tuple2<String, Integer>> result = op.executeOnCollections(input, null);
Set<Tuple2<String, Integer>> resultSet = new HashSet<Tuple2<String, Integer>>(result);
List<Tuple2<String, Integer>> resultMutableSafe = op.executeOnCollections(input, null, true);
List<Tuple2<String, Integer>> resultRegular = op.executeOnCollections(input, null, false);
Set<Tuple2<String, Integer>> resultSetMutableSafe = new HashSet<Tuple2<String, Integer>>(resultMutableSafe);
Set<Tuple2<String, Integer>> resultSetRegular = new HashSet<Tuple2<String, Integer>>(resultRegular);
Set<Tuple2<String, Integer>> expectedResult = new HashSet<Tuple2<String,
Integer>>(asList(new Tuple2<String, Integer>("foo", 4), new Tuple2<String,
Integer>("bar", 6)));
assertEquals(expectedResult, resultSet);
} catch (Exception e) {
assertEquals(expectedResult, resultSetMutableSafe);
assertEquals(expectedResult, resultSetRegular);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
......@@ -127,20 +132,23 @@ public class ReduceOperatorTest implements java.io.Serializable {
Integer>("foo", 3), new Tuple2<String, Integer>("bar", 2), new Tuple2<String,
Integer>("bar", 4)));
List<Tuple2<String, Integer>> result = op.executeOnCollections(input,
new RuntimeUDFContext(taskName, 1, 0));
List<Tuple2<String, Integer>> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0), true);
List<Tuple2<String, Integer>> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0), false);
Set<Tuple2<String, Integer>> resultSet = new HashSet<Tuple2<String, Integer>>(result);
Set<Tuple2<String, Integer>> resultSetMutableSafe = new HashSet<Tuple2<String, Integer>>(resultMutableSafe);
Set<Tuple2<String, Integer>> resultSetRegular = new HashSet<Tuple2<String, Integer>>(resultRegular);
Set<Tuple2<String, Integer>> expectedResult = new HashSet<Tuple2<String,
Integer>>(asList(new Tuple2<String, Integer>("foo", 4), new Tuple2<String,
Integer>("bar", 6)));
assertEquals(expectedResult, resultSet);
assertEquals(expectedResult, resultSetMutableSafe);
assertEquals(expectedResult, resultSetRegular);
assertTrue(opened.get());
assertTrue(closed.get());
} catch (Exception e) {
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册