Commit 8bf9416d authored by Gyula Fora

[streaming] Streaming API grouping rework to use batch api Keys

Parent 1d019b9b
......@@ -33,6 +33,7 @@ import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.ClosureCleaner;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.Keys;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
......@@ -66,8 +67,7 @@ import org.apache.flink.streaming.partitioner.DistributePartitioner;
import org.apache.flink.streaming.partitioner.FieldsPartitioner;
import org.apache.flink.streaming.partitioner.ShufflePartitioner;
import org.apache.flink.streaming.partitioner.StreamPartitioner;
import org.apache.flink.streaming.util.keys.FieldsKeySelector;
import org.apache.flink.streaming.util.keys.PojoKeySelector;
import org.apache.flink.streaming.util.keys.KeySelectorUtil;
/**
* A DataStream represents a stream of elements of the same type. A DataStream
......@@ -245,9 +245,11 @@ public class DataStream<OUT> {
* @return The grouped {@link DataStream}
*/
public GroupedDataStream<OUT> groupBy(int... fields) {
return groupBy(FieldsKeySelector.getSelector(getType(), fields));
if (getType() instanceof BasicArrayTypeInfo || getType() instanceof PrimitiveArrayTypeInfo) {
return groupBy(new KeySelectorUtil.ArrayKeySelector<OUT>(fields));
} else {
return groupBy(new Keys.ExpressionKeys<OUT>(fields, getType()));
}
}
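A minimal usage sketch of the reworked position-based grouping (the env variable, the fromElements call, and the sample data are illustrative assumptions, not part of this commit):

    // Tuple streams are now keyed through the batch API's Keys.ExpressionKeys;
    // array-typed streams fall back to KeySelectorUtil.ArrayKeySelector.
    DataStream<Tuple2<String, Integer>> words = env.fromElements(
            new Tuple2<String, Integer>("hello", 1),
            new Tuple2<String, Integer>("world", 2));
    GroupedDataStream<Tuple2<String, Integer>> grouped = words.groupBy(0);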
/**
......@@ -264,7 +266,7 @@ public class DataStream<OUT> {
**/
public GroupedDataStream<OUT> groupBy(String... fields) {
return groupBy(new PojoKeySelector<OUT>(getType(), fields));
return groupBy(new Keys.ExpressionKeys<OUT>(fields, getType()));
}
......@@ -282,6 +284,11 @@ public class DataStream<OUT> {
return new GroupedDataStream<OUT>(this, clean(keySelector));
}
private GroupedDataStream<OUT> groupBy(Keys<OUT> keys) {
return new GroupedDataStream<OUT>(this, clean(KeySelectorUtil.getSelectorForKeys(keys,
getType())));
}
/**
* Sets the partitioning of the {@link DataStream} so that the output is
* partitioned by the selected fields. This setting only affects how the
......@@ -293,9 +300,11 @@ public class DataStream<OUT> {
* @return The DataStream with fields partitioning set.
*/
public DataStream<OUT> partitionBy(int... fields) {
return setConnectionType(new FieldsPartitioner<OUT>(FieldsKeySelector.getSelector(
getType(), fields)));
if (getType() instanceof BasicArrayTypeInfo || getType() instanceof PrimitiveArrayTypeInfo) {
return partitionBy(new KeySelectorUtil.ArrayKeySelector<OUT>(fields));
} else {
return partitionBy(new Keys.ExpressionKeys<OUT>(fields, getType()));
}
}
/**
......@@ -309,9 +318,11 @@ public class DataStream<OUT> {
* @return The DataStream with fields partitioning set.
*/
public DataStream<OUT> partitionBy(String... fields) {
return setConnectionType(new FieldsPartitioner<OUT>(new PojoKeySelector<OUT>(getType(),
fields)));
return partitionBy(new Keys.ExpressionKeys<OUT>(fields, getType()));
}
private DataStream<OUT> partitionBy(Keys<OUT> keys) {
return partitionBy(KeySelectorUtil.getSelectorForKeys(keys, getType()));
}
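A hedged sketch of expression-based partitioning after the rework (the Word POJO with a public word field and the words collection are hypothetical):

    // The field expression is resolved by Keys.ExpressionKeys and converted
    // to a KeySelector before being handed to the FieldsPartitioner.
    DataStream<Word> partitioned = env.fromCollection(words)
            .partitionBy("word");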
/**
......@@ -411,7 +422,7 @@ public class DataStream<OUT> {
* the data stream that will be fed back and used as the input for the
* iteration head. A common usage pattern for streaming iterations is to use
* output splitting to send a part of the closing data stream to the head.
* Refer to {@link SingleOutputStreamOperator#split(OutputSelector)} for
* Refer to {@link SingleOutputStreamOperator#split(outputSelector)} for
* more information.
* <p>
* The iteration edge will be partitioned the same way as the first input of
......@@ -549,7 +560,7 @@ public class DataStream<OUT> {
* {@link StreamCrossOperator#onWindow} should be called to define the
* window.
* <p>
* Call {@link StreamCrossOperator.CrossWindow#with(CrossFunction)} to
* Call {@link StreamCrossOperator.CrossWindow#with(crossFunction)} to
* define a custom cross function.
*
* @param dataStreamToCross
......@@ -572,7 +583,7 @@ public class DataStream<OUT> {
* window, and then the {@link StreamJoinOperator.JoinWindow#where} and
* {@link StreamJoinOperator.JoinPredicate#equalTo} can be used to define
* the join keys.</p> The user can also use the
* {@link StreamJoinOperator.JoinedStream#with(JoinFunction)} to apply
* {@link StreamJoinOperator.JoinedStream#with(joinFunction)} to apply
* custom join function.
*
* @param other
......
......@@ -21,14 +21,14 @@ package org.apache.flink.streaming.api.datastream;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.Keys;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.streaming.api.function.co.JoinWindowFunction;
import org.apache.flink.streaming.api.invokable.operator.co.CoWindowInvokable;
import org.apache.flink.streaming.util.keys.FieldsKeySelector;
import org.apache.flink.streaming.util.keys.PojoKeySelector;
import org.apache.flink.streaming.util.keys.KeySelectorUtil;
public class StreamJoinOperator<I1, I2> extends
TemporalOperator<I1, I2, StreamJoinOperator.JoinWindow<I1, I2>> {
......@@ -45,9 +45,11 @@ public class StreamJoinOperator<I1, I2> extends
public static class JoinWindow<I1, I2> {
private StreamJoinOperator<I1, I2> op;
private TypeInformation<I1> type1;
private JoinWindow(StreamJoinOperator<I1, I2> operator) {
this.op = operator;
this.type1 = op.input1.getType();
}
/**
......@@ -64,8 +66,8 @@ public class StreamJoinOperator<I1, I2> extends
* {@link JoinPredicate#equalTo} to continue the Join.
*/
public JoinPredicate<I1, I2> where(int... fields) {
return new JoinPredicate<I1, I2>(op, FieldsKeySelector.getSelector(op.input1.getType(),
fields));
return new JoinPredicate<I1, I2>(op, KeySelectorUtil.getSelectorForKeys(
new Keys.ExpressionKeys<I1>(fields, type1), type1));
}
/**
......@@ -81,8 +83,8 @@ public class StreamJoinOperator<I1, I2> extends
* {@link JoinPredicate#equalTo} to continue the Join.
*/
public JoinPredicate<I1, I2> where(String... fields) {
return new JoinPredicate<I1, I2>(op, new PojoKeySelector<I1>(op.input1.getType(),
fields));
return new JoinPredicate<I1, I2>(op, KeySelectorUtil.getSelectorForKeys(
new Keys.ExpressionKeys<I1>(fields, type1), type1));
}
/**
......@@ -114,13 +116,15 @@ public class StreamJoinOperator<I1, I2> extends
*/
public static class JoinPredicate<I1, I2> {
public StreamJoinOperator<I1, I2> op;
public KeySelector<I1, ?> keys1;
public KeySelector<I2, ?> keys2;
private StreamJoinOperator<I1, I2> op;
private KeySelector<I1, ?> keys1;
private KeySelector<I2, ?> keys2;
private TypeInformation<I2> type2;
private JoinPredicate(StreamJoinOperator<I1, I2> operator, KeySelector<I1, ?> keys1) {
this.op = operator;
this.keys1 = keys1;
this.type2 = op.input2.getType();
}
/**
......@@ -138,7 +142,8 @@ public class StreamJoinOperator<I1, I2> extends
* apply a custom wrapping
*/
public JoinedStream<I1, I2> equalTo(int... fields) {
keys2 = FieldsKeySelector.getSelector(op.input2.getType(), fields);
keys2 = KeySelectorUtil.getSelectorForKeys(new Keys.ExpressionKeys<I2>(fields, type2),
type2);
return createJoinOperator();
}
......@@ -156,7 +161,8 @@ public class StreamJoinOperator<I1, I2> extends
* apply a custom wrapping
*/
public JoinedStream<I1, I2> equalTo(String... fields) {
this.keys2 = new PojoKeySelector<I2>(op.input2.getType(), fields);
this.keys2 = KeySelectorUtil.getSelectorForKeys(new Keys.ExpressionKeys<I2>(fields,
type2), type2);
return createJoinOperator();
}
......
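A sketch of the reworked where/equalTo chain, modeled on the tests further below (the two-argument onWindow overload and the join function body are assumptions for brevity):

    // Both key definitions now go through Keys.ExpressionKeys and
    // KeySelectorUtil.getSelectorForKeys instead of the old FieldsKeySelector.
    stream1.join(stream2)
            .onWindow(1000, 1000)
            .where(0)
            .equalTo(0)
            .with(new JoinFunction<Tuple2<Integer, String>, Tuple1<Integer>, String>() {
                private static final long serialVersionUID = 1L;

                @Override
                public String join(Tuple2<Integer, String> first, Tuple1<Integer> second) {
                    return first.f1 + "-" + second.f0;
                }
            });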
......@@ -41,8 +41,12 @@ public class JoinWindowFunction<IN1, IN2, OUT> implements CoWindowFunction<IN1,
@Override
public void coWindow(List<IN1> first, List<IN2> second, Collector<OUT> out) throws Exception {
for (IN1 item1 : first) {
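// Extract item1's key once per outer element; the previous version
// recomputed it inside the inner loop for every (item1, item2) pair.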
Object key1 = keySelector1.getKey(item1);
for (IN2 item2 : second) {
if (keySelector1.getKey(item1).equals(keySelector2.getKey(item2))) {
Object key2 = keySelector2.getKey(item2);
if (key1.equals(key2)) {
out.collect(joinFunction.join(item1, item2));
}
}
......
......@@ -59,7 +59,7 @@ public class ClusterUtil {
exec = new LocalFlinkMiniCluster(configuration, true);
ActorRef jobClient = exec.getJobClient();
JobClient.submitJobAndWait(jobGraph, false, jobClient, exec.timeout());
JobClient.submitJobAndWait(jobGraph, true, jobClient, exec.timeout());
} catch (Exception e) {
throw e;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.util.keys;
import java.lang.reflect.Array;
import org.apache.flink.api.java.tuple.Tuple;
public class ArrayKeySelector<IN> extends FieldsKeySelector<IN> {
private static final long serialVersionUID = 1L;
public ArrayKeySelector(int... fields) {
super(fields);
}
@Override
public Object getKey(IN value) throws Exception {
if (simpleKey) {
return Array.get(value, keyFields[0]);
} else {
int c = 0;
for (int pos : keyFields) {
((Tuple) key).setField(Array.get(value, pos), c);
c++;
}
return key;
}
}
}
......@@ -17,10 +17,13 @@
package org.apache.flink.streaming.util.keys;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import java.lang.reflect.Array;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.Keys;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple10;
......@@ -48,56 +51,76 @@ import org.apache.flink.api.java.tuple.Tuple7;
import org.apache.flink.api.java.tuple.Tuple8;
import org.apache.flink.api.java.tuple.Tuple9;
public abstract class FieldsKeySelector<IN> implements KeySelector<IN, Object> {
public class KeySelectorUtil {
public static Class<?>[] tupleClasses = new Class[] { Tuple1.class, Tuple2.class, Tuple3.class,
Tuple4.class, Tuple5.class, Tuple6.class, Tuple7.class, Tuple8.class, Tuple9.class,
Tuple10.class, Tuple11.class, Tuple12.class, Tuple13.class, Tuple14.class,
Tuple15.class, Tuple16.class, Tuple17.class, Tuple18.class, Tuple19.class,
Tuple20.class, Tuple21.class, Tuple22.class, Tuple23.class, Tuple24.class,
Tuple25.class };
public static <X> KeySelector<X, ?> getSelectorForKeys(Keys<X> keys, TypeInformation<X> typeInfo) {
int[] logicalKeyPositions = keys.computeLogicalKeyPositions();
int keyLength = logicalKeyPositions.length;
boolean[] orders = new boolean[keyLength];
TypeComparator<X> comparator = ((CompositeType<X>) typeInfo).createComparator(
logicalKeyPositions, orders, 0);
return new ComparableKeySelector<X>(comparator, keyLength);
}
private static final long serialVersionUID = 1L;
public static class ComparableKeySelector<IN> implements KeySelector<IN, Tuple> {
protected int[] keyFields;
protected Object key;
protected boolean simpleKey;
private static final long serialVersionUID = 1L;
@SuppressWarnings("unchecked")
public static Class<? extends Tuple>[] tupleClasses = new Class[] { Tuple1.class, Tuple2.class,
Tuple3.class, Tuple4.class, Tuple5.class, Tuple6.class, Tuple7.class, Tuple8.class,
Tuple9.class, Tuple10.class, Tuple11.class, Tuple12.class, Tuple13.class,
Tuple14.class, Tuple15.class, Tuple16.class, Tuple17.class, Tuple18.class,
Tuple19.class, Tuple20.class, Tuple21.class, Tuple22.class, Tuple23.class,
Tuple24.class, Tuple25.class };
private TypeComparator<IN> comparator;
private int keyLength;
private Object[] keyArray;
private Tuple key;
public FieldsKeySelector(int... fields) {
this.keyFields = fields;
this.simpleKey = fields.length == 1;
for (int i : fields) {
if (i < 0) {
throw new RuntimeException("Grouping fields must be non-negative");
public ComparableKeySelector(TypeComparator<IN> comparator, int keyLength) {
this.comparator = comparator;
this.keyLength = keyLength;
keyArray = new Object[keyLength];
try {
key = (Tuple) tupleClasses[keyLength - 1].newInstance();
} catch (Exception e) {
// the TupleN classes are always instantiable for 1 <= keyLength <= 25,
// so surface anything unexpected instead of silently leaving key null
throw new RuntimeException(e);
}
}
try {
key = tupleClasses[fields.length - 1].newInstance();
} catch (Exception e) {
throw new RuntimeException(e.getMessage());
@Override
public Tuple getKey(IN value) throws Exception {
comparator.extractKeys(value, keyArray, 0);
for (int i = 0; i < keyLength; i++) {
key.setField(keyArray[i], i);
}
return key;
}
}
public static <R> KeySelector<R, ?> getSelector(TypeInformation<R> type, int... fields) {
if (type.isTupleType()) {
return new TupleKeySelector<R>(fields);
} else if (type instanceof BasicArrayTypeInfo || type instanceof PrimitiveArrayTypeInfo) {
return new ArrayKeySelector<R>(fields);
} else {
if (fields.length > 1) {
throw new RuntimeException(
"For non-tuple types use single field 0 or KeyExctractor for grouping");
} else if (fields[0] > 0) {
throw new RuntimeException(
"For simple objects grouping only allowed on the first field");
} else {
return new ObjectKeySelector<R>();
public static class ArrayKeySelector<IN> implements KeySelector<IN, Tuple> {
private static final long serialVersionUID = 1L;
Tuple key;
int[] fields;
public ArrayKeySelector(int... fields) {
this.fields = fields;
try {
key = (Tuple) tupleClasses[fields.length - 1].newInstance();
} catch (Exception e) {
// as above, reflective TupleN instantiation should never fail here
throw new RuntimeException(e);
}
}
}
@Override
public Tuple getKey(IN value) throws Exception {
for (int i = 0; i < fields.length; i++) {
int pos = fields[i];
key.setField(Array.get(value, pos), i);
}
return key;
}
}
}
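A hedged sketch of using the new utility directly (the sample values are assumptions; compare the test changes below, which use the same calls):

    // Build a KeySelector for field 0 of a Tuple2 stream via the batch-API Keys.
    TypeInformation<Tuple2<Integer, Integer>> typeInfo =
            TypeExtractor.getForObject(new Tuple2<Integer, Integer>(1, 1));
    KeySelector<Tuple2<Integer, Integer>, ?> selector = KeySelectorUtil.getSelectorForKeys(
            new Keys.ExpressionKeys<Tuple2<Integer, Integer>>(new int[] { 0 }, typeInfo),
            typeInfo);

    // For array-typed streams, ArrayKeySelector reads positions reflectively
    // and packs them into a TupleN key.
    KeySelector<double[], Tuple> arrayKeys =
            new KeySelectorUtil.ArrayKeySelector<double[]>(1, 0);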
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.util.keys;
import org.apache.flink.api.java.functions.KeySelector;
public class ObjectKeySelector<IN> implements KeySelector<IN, IN> {
private static final long serialVersionUID = 1L;
@Override
public IN getKey(IN value) throws Exception {
return value;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.util.keys;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.PojoTypeInfo;
import org.apache.flink.api.java.typeutils.runtime.PojoComparator;
public class PojoKeySelector<IN> extends FieldsKeySelector<IN> {
private static final long serialVersionUID = 1L;
PojoComparator<IN> comparator;
public PojoKeySelector(TypeInformation<IN> type, String... fields) {
super(new int[removeDuplicates(fields).length]);
if (!(type instanceof CompositeType<?>)) {
throw new IllegalArgumentException(
"Key expressions are only supported on POJO types and Tuples. "
+ "A type is considered a POJO if all its fields are public, or have both getters and setters defined");
}
CompositeType<IN> cType = (CompositeType<IN>) type;
String[] keyFields = removeDuplicates(fields);
int numOfKeys = keyFields.length;
List<FlatFieldDescriptor> fieldDescriptors = new ArrayList<FlatFieldDescriptor>();
for (String field : keyFields) {
cType.getKey(field, 0, fieldDescriptors);
}
int[] logicalKeyPositions = new int[numOfKeys];
boolean[] orders = new boolean[numOfKeys];
for (int i = 0; i < numOfKeys; i++) {
logicalKeyPositions[i] = fieldDescriptors.get(i).getPosition();
}
if (cType instanceof PojoTypeInfo) {
comparator = (PojoComparator<IN>) cType
.createComparator(logicalKeyPositions, orders, 0);
} else {
throw new IllegalArgumentException(
"Key expressions are only supported on POJO types. "
+ "A type is considered a POJO if all its fields are public, or have both getters and setters defined");
}
}
@Override
public Object getKey(IN value) throws Exception {
Field[] keyFields = comparator.getKeyFields();
if (simpleKey) {
return comparator.accessField(keyFields[0], value);
} else {
int c = 0;
for (Field field : keyFields) {
((Tuple) key).setField(comparator.accessField(field, value), c);
c++;
}
}
return key;
}
private static String[] removeDuplicates(String[] in) {
List<String> ret = new LinkedList<String>();
for (String el : in) {
if (!ret.contains(el)) {
ret.add(el);
}
}
return ret.toArray(new String[ret.size()]);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.util.keys;
import org.apache.flink.api.java.tuple.Tuple;
public class TupleKeySelector<IN> extends FieldsKeySelector<IN> {
private static final long serialVersionUID = 1L;
public TupleKeySelector(int... fields) {
super(fields);
}
@Override
public Object getKey(IN value) throws Exception {
if (simpleKey) {
return ((Tuple) value).getField(keyFields[0]);
} else {
int c = 0;
for (int pos : keyFields) {
((Tuple) key).setField(((Tuple) value).getField(pos), c);
c++;
}
return key;
}
}
}
......@@ -25,6 +25,8 @@ import java.util.List;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.Keys;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
......@@ -34,7 +36,7 @@ import org.apache.flink.streaming.api.function.aggregation.SumAggregator;
import org.apache.flink.streaming.api.invokable.operator.GroupedReduceInvokable;
import org.apache.flink.streaming.api.invokable.operator.StreamReduceInvokable;
import org.apache.flink.streaming.util.MockContext;
import org.apache.flink.streaming.util.keys.TupleKeySelector;
import org.apache.flink.streaming.util.keys.KeySelectorUtil;
import org.junit.Test;
public class AggregationFunctionTest {
......@@ -94,14 +96,14 @@ public class AggregationFunctionTest {
Integer.class, type1);
ReduceFunction<Integer> sumFunction0 = SumAggregator
.getSumFunction(0, Integer.class, type2);
ReduceFunction<Tuple2<Integer, Integer>> minFunction = ComparableAggregator
.getAggregator(1, type1, AggregationType.MIN);
ReduceFunction<Integer> minFunction0 = ComparableAggregator.getAggregator(0,
type2, AggregationType.MIN);
ReduceFunction<Tuple2<Integer, Integer>> maxFunction = ComparableAggregator
.getAggregator(1, type1, AggregationType.MAX);
ReduceFunction<Integer> maxFunction0 = ComparableAggregator.getAggregator(0,
type2, AggregationType.MAX);
ReduceFunction<Tuple2<Integer, Integer>> minFunction = ComparableAggregator.getAggregator(
1, type1, AggregationType.MIN);
ReduceFunction<Integer> minFunction0 = ComparableAggregator.getAggregator(0, type2,
AggregationType.MIN);
ReduceFunction<Tuple2<Integer, Integer>> maxFunction = ComparableAggregator.getAggregator(
1, type1, AggregationType.MAX);
ReduceFunction<Integer> maxFunction0 = ComparableAggregator.getAggregator(0, type2,
AggregationType.MAX);
List<Tuple2<Integer, Integer>> sumList = MockContext.createAndExecute(
new StreamReduceInvokable<Tuple2<Integer, Integer>>(sumFunction), getInputList());
......@@ -111,17 +113,24 @@ public class AggregationFunctionTest {
List<Tuple2<Integer, Integer>> maxList = MockContext.createAndExecute(
new StreamReduceInvokable<Tuple2<Integer, Integer>>(maxFunction), getInputList());
TypeInformation<Tuple2<Integer, Integer>> typeInfo = TypeExtractor
.getForObject(new Tuple2<Integer, Integer>(1, 1));
KeySelector<Tuple2<Integer, Integer>, ?> keySelector = KeySelectorUtil.getSelectorForKeys(
new Keys.ExpressionKeys<Tuple2<Integer, Integer>>(new int[] { 0 }, typeInfo),
typeInfo);
List<Tuple2<Integer, Integer>> groupedSumList = MockContext.createAndExecute(
new GroupedReduceInvokable<Tuple2<Integer, Integer>>(sumFunction,
new TupleKeySelector<Tuple2<Integer, Integer>>(0)), getInputList());
new GroupedReduceInvokable<Tuple2<Integer, Integer>>(sumFunction, keySelector),
getInputList());
List<Tuple2<Integer, Integer>> groupedMinList = MockContext.createAndExecute(
new GroupedReduceInvokable<Tuple2<Integer, Integer>>(minFunction,
new TupleKeySelector<Tuple2<Integer, Integer>>(0)), getInputList());
new GroupedReduceInvokable<Tuple2<Integer, Integer>>(minFunction, keySelector),
getInputList());
List<Tuple2<Integer, Integer>> groupedMaxList = MockContext.createAndExecute(
new GroupedReduceInvokable<Tuple2<Integer, Integer>>(maxFunction,
new TupleKeySelector<Tuple2<Integer, Integer>>(0)), getInputList());
new GroupedReduceInvokable<Tuple2<Integer, Integer>>(maxFunction, keySelector),
getInputList());
assertEquals(expectedSumList, sumList);
assertEquals(expectedMinList, minList);
......
......@@ -23,6 +23,7 @@ import java.io.Serializable;
import java.util.ArrayList;
import org.apache.flink.api.common.functions.CrossFunction;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
......@@ -49,16 +50,16 @@ public class WindowCrossJoinTest implements Serializable {
env.setBufferTimeout(1);
ArrayList<Tuple2<Integer, String>> in1 = new ArrayList<Tuple2<Integer, String>>();
ArrayList<Integer> in2 = new ArrayList<Integer>();
ArrayList<Tuple1<Integer>> in2 = new ArrayList<Tuple1<Integer>>();
in1.add(new Tuple2<Integer, String>(10, "a"));
in1.add(new Tuple2<Integer, String>(20, "b"));
in1.add(new Tuple2<Integer, String>(20, "x"));
in1.add(new Tuple2<Integer, String>(0, "y"));
in2.add(0);
in2.add(5);
in2.add(20);
in2.add(new Tuple1<Integer>(0));
in2.add(new Tuple1<Integer>(5));
in2.add(new Tuple1<Integer>(20));
joinExpectedResults.add(new Tuple2<Tuple2<Integer, String>, Integer>(
new Tuple2<Integer, String>(20, "b"), 20));
......@@ -93,23 +94,24 @@ public class WindowCrossJoinTest implements Serializable {
new Tuple2<Integer, String>(0, "y"), 20));
DataStream<Tuple2<Integer, String>> inStream1 = env.fromCollection(in1);
DataStream<Integer> inStream2 = env.fromCollection(in2);
DataStream<Tuple1<Integer>> inStream2 = env.fromCollection(in2);
inStream1.join(inStream2).onWindow(1000, 1000, new MyTimestamp1(), new MyTimestamp2())
.where(0).equalTo(0).addSink(new JoinResultSink());
inStream1.cross(inStream2).onWindow(1000, 1000, new MyTimestamp1(), new MyTimestamp2())
.with(new CrossFunction<Tuple2<Integer,String>, Integer, Tuple2<Tuple2<Integer,String>, Integer>>() {
inStream1
.cross(inStream2)
.onWindow(1000, 1000, new MyTimestamp1(), new MyTimestamp2())
.with(new CrossFunction<Tuple2<Integer, String>, Tuple1<Integer>, Tuple2<Tuple2<Integer, String>, Tuple1<Integer>>>() {
private static final long serialVersionUID = 1L;
@Override
public Tuple2<Tuple2<Integer, String>, Integer> cross(
Tuple2<Integer, String> val1, Integer val2) throws Exception {
return new Tuple2<Tuple2<Integer,String>, Integer>(val1, val2);
public Tuple2<Tuple2<Integer, String>, Tuple1<Integer>> cross(
Tuple2<Integer, String> val1, Tuple1<Integer> val2) throws Exception {
return new Tuple2<Tuple2<Integer, String>, Tuple1<Integer>>(val1, val2);
}
})
.addSink(new CrossResultSink());
}).addSink(new CrossResultSink());
env.execute();
......@@ -131,11 +133,11 @@ public class WindowCrossJoinTest implements Serializable {
}
}
private static class MyTimestamp2 implements TimeStamp<Integer> {
private static class MyTimestamp2 implements TimeStamp<Tuple1<Integer>> {
private static final long serialVersionUID = 1L;
@Override
public long getTimestamp(Integer value) {
public long getTimestamp(Tuple1<Integer> value) {
return 101L;
}
......@@ -146,22 +148,22 @@ public class WindowCrossJoinTest implements Serializable {
}
private static class JoinResultSink implements
SinkFunction<Tuple2<Tuple2<Integer, String>, Integer>> {
SinkFunction<Tuple2<Tuple2<Integer, String>, Tuple1<Integer>>> {
private static final long serialVersionUID = 1L;
@Override
public void invoke(Tuple2<Tuple2<Integer, String>, Integer> value) {
joinResults.add(value);
public void invoke(Tuple2<Tuple2<Integer, String>, Tuple1<Integer>> value) {
joinResults.add(new Tuple2<Tuple2<Integer, String>, Integer>(value.f0, value.f1.f0));
}
}
private static class CrossResultSink implements
SinkFunction<Tuple2<Tuple2<Integer, String>, Integer>> {
SinkFunction<Tuple2<Tuple2<Integer, String>, Tuple1<Integer>>> {
private static final long serialVersionUID = 1L;
@Override
public void invoke(Tuple2<Tuple2<Integer, String>, Integer> value) {
crossResults.add(value);
public void invoke(Tuple2<Tuple2<Integer, String>, Tuple1<Integer>> value) {
crossResults.add(new Tuple2<Tuple2<Integer, String>, Integer>(value.f0, value.f1.f0));
}
}
}
......@@ -23,15 +23,35 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.function.co.CoReduceFunction;
import org.apache.flink.streaming.api.invokable.operator.co.CoGroupedBatchReduceInvokable;
import org.apache.flink.streaming.util.MockCoContext;
import org.apache.flink.streaming.util.keys.TupleKeySelector;
import org.junit.Test;
public class CoGroupedBatchReduceTest {
KeySelector<Tuple2<String, String>, ?> keySelector1 = new KeySelector<Tuple2<String, String>, String>() {
private static final long serialVersionUID = 1L;
@Override
public String getKey(Tuple2<String, String> value) throws Exception {
return value.f0;
}
};
KeySelector<Tuple2<String, Integer>, ?> keySelector2 = new KeySelector<Tuple2<String, Integer>, String>() {
private static final long serialVersionUID = 1L;
@Override
public String getKey(Tuple2<String, Integer> value) throws Exception {
return value.f0;
}
};
private static class MyCoReduceFunction implements
CoReduceFunction<Tuple2<String, Integer>, Tuple2<String, String>, String> {
private static final long serialVersionUID = 1L;
......@@ -59,7 +79,6 @@ public class CoGroupedBatchReduceTest {
}
}
@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
public void coGroupedBatchReduceTest1() {
......@@ -96,8 +115,7 @@ public class CoGroupedBatchReduceTest {
expected.add("h");
CoGroupedBatchReduceInvokable<Tuple2<String, Integer>, Tuple2<String, String>, String> invokable = new CoGroupedBatchReduceInvokable<Tuple2<String, Integer>, Tuple2<String, String>, String>(
new MyCoReduceFunction(), 4L, 3L, 4L, 3L, new TupleKeySelector(0),
new TupleKeySelector(0));
new MyCoReduceFunction(), 4L, 3L, 4L, 3L, keySelector2, keySelector1);
List<String> result = MockCoContext.createAndExecute(invokable, inputs1, inputs2);
......@@ -106,7 +124,6 @@ public class CoGroupedBatchReduceTest {
assertEquals(expected, result);
}
@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
public void coGroupedBatchReduceTest2() {
......@@ -143,8 +160,7 @@ public class CoGroupedBatchReduceTest {
expected.add("fh");
CoGroupedBatchReduceInvokable<Tuple2<String, Integer>, Tuple2<String, String>, String> invokable = new CoGroupedBatchReduceInvokable<Tuple2<String, Integer>, Tuple2<String, String>, String>(
new MyCoReduceFunction(), 4L, 3L, 2L, 2L, new TupleKeySelector(0),
new TupleKeySelector(0));
new MyCoReduceFunction(), 4L, 3L, 2L, 2L, keySelector2, keySelector1);
List<String> result = MockCoContext.createAndExecute(invokable, inputs1, inputs2);
......
......@@ -22,12 +22,12 @@ import static org.junit.Assert.assertEquals;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.streaming.api.function.co.CoReduceFunction;
import org.apache.flink.streaming.api.invokable.operator.co.CoGroupedReduceInvokable;
import org.apache.flink.streaming.util.MockCoContext;
import org.apache.flink.streaming.util.keys.TupleKeySelector;
import org.junit.Test;
public class CoGroupedReduceTest {
......@@ -59,7 +59,7 @@ public class CoGroupedReduceTest {
}
}
@SuppressWarnings({ "unchecked", "rawtypes" })
@SuppressWarnings("unchecked")
@Test
public void coGroupedReduceTest() {
Tuple3<String, String, String> word1 = new Tuple3<String, String, String>("a", "word1", "b");
......@@ -71,8 +71,38 @@ public class CoGroupedReduceTest {
Tuple2<Integer, Integer> int4 = new Tuple2<Integer, Integer>(2, 4);
Tuple2<Integer, Integer> int5 = new Tuple2<Integer, Integer>(1, 5);
KeySelector<Tuple3<String, String, String>, ?> keySelector0 = new KeySelector<Tuple3<String, String, String>, String>() {
private static final long serialVersionUID = 1L;
@Override
public String getKey(Tuple3<String, String, String> value) throws Exception {
return value.f0;
}
};
KeySelector<Tuple2<Integer, Integer>, ?> keySelector1 = new KeySelector<Tuple2<Integer, Integer>, Integer>() {
private static final long serialVersionUID = 1L;
@Override
public Integer getKey(Tuple2<Integer, Integer> value) throws Exception {
return value.f0;
}
};
KeySelector<Tuple3<String, String, String>, ?> keySelector2 = new KeySelector<Tuple3<String, String, String>, String>() {
private static final long serialVersionUID = 1L;
@Override
public String getKey(Tuple3<String, String, String> value) throws Exception {
return value.f2;
}
};
CoGroupedReduceInvokable<Tuple3<String, String, String>, Tuple2<Integer, Integer>, String> invokable = new CoGroupedReduceInvokable<Tuple3<String, String, String>, Tuple2<Integer, Integer>, String>(
new MyCoReduceFunction(), new TupleKeySelector(0), new TupleKeySelector(0));
new MyCoReduceFunction(), keySelector0, keySelector1);
List<String> expected = Arrays.asList("word1", "1", "word2", "2", "word1word3", "3", "5",
"7");
......@@ -83,12 +113,12 @@ public class CoGroupedReduceTest {
assertEquals(expected, actualList);
invokable = new CoGroupedReduceInvokable<Tuple3<String, String, String>, Tuple2<Integer, Integer>, String>(
new MyCoReduceFunction(), new TupleKeySelector(2), new TupleKeySelector(0));
new MyCoReduceFunction(), keySelector2, keySelector1);
expected = Arrays.asList("word1", "1", "word2", "2", "word2word3", "3", "5", "7");
actualList = MockCoContext.createAndExecute(invokable,
Arrays.asList(word1, word2, word3), Arrays.asList(int1, int2, int3, int4, int5));
actualList = MockCoContext.createAndExecute(invokable, Arrays.asList(word1, word2, word3),
Arrays.asList(int1, int2, int3, int4, int5));
assertEquals(expected, actualList);
}
......
......@@ -25,16 +25,36 @@ import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.function.co.CoReduceFunction;
import org.apache.flink.streaming.api.invokable.operator.co.CoGroupedWindowReduceInvokable;
import org.apache.flink.streaming.api.invokable.util.TimeStamp;
import org.apache.flink.streaming.util.MockCoContext;
import org.apache.flink.streaming.util.keys.TupleKeySelector;
import org.junit.Test;
public class CoGroupedWindowReduceTest {
KeySelector<Tuple2<String, Integer>, ?> keySelector0 = new KeySelector<Tuple2<String, Integer>, String>() {
private static final long serialVersionUID = 1L;
@Override
public String getKey(Tuple2<String, Integer> value) throws Exception {
return value.f0;
}
};
KeySelector<Tuple2<String, String>, ?> keySelector1 = new KeySelector<Tuple2<String, String>, String>() {
private static final long serialVersionUID = 1L;
@Override
public String getKey(Tuple2<String, String> value) throws Exception {
return value.f0;
}
};
private static class MyCoReduceFunction implements
CoReduceFunction<Tuple2<String, Integer>, Tuple2<String, String>, String> {
private static final long serialVersionUID = 1L;
......@@ -85,7 +105,6 @@ public class CoGroupedWindowReduceTest {
}
}
@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
public void coGroupedWindowReduceTest1() {
......@@ -125,9 +144,9 @@ public class CoGroupedWindowReduceTest {
expected.add("i");
CoGroupedWindowReduceInvokable<Tuple2<String, Integer>, Tuple2<String, String>, String> invokable = new CoGroupedWindowReduceInvokable<Tuple2<String, Integer>, Tuple2<String, String>, String>(
new MyCoReduceFunction(), 4L, 3L, 4L, 3L, new TupleKeySelector(0),
new TupleKeySelector( 0), new MyTimeStamp<Tuple2<String, Integer>>(
timestamps1), new MyTimeStamp<Tuple2<String, String>>(timestamps2));
new MyCoReduceFunction(), 4L, 3L, 4L, 3L, keySelector0, keySelector1,
new MyTimeStamp<Tuple2<String, Integer>>(timestamps1),
new MyTimeStamp<Tuple2<String, String>>(timestamps2));
List<String> result = MockCoContext.createAndExecute(invokable, inputs1, inputs2);
......@@ -136,7 +155,6 @@ public class CoGroupedWindowReduceTest {
assertEquals(expected, result);
}
@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
public void coGroupedWindowReduceTest2() {
......@@ -178,9 +196,9 @@ public class CoGroupedWindowReduceTest {
expected.add("fh");
CoGroupedWindowReduceInvokable<Tuple2<String, Integer>, Tuple2<String, String>, String> invokable = new CoGroupedWindowReduceInvokable<Tuple2<String, Integer>, Tuple2<String, String>, String>(
new MyCoReduceFunction(), 4L, 3L, 2L, 2L, new TupleKeySelector( 0),
new TupleKeySelector( 0), new MyTimeStamp<Tuple2<String, Integer>>(
timestamps1), new MyTimeStamp<Tuple2<String, String>>(timestamps2));
new MyCoReduceFunction(), 4L, 3L, 2L, 2L, keySelector0, keySelector1,
new MyTimeStamp<Tuple2<String, Integer>>(timestamps1),
new MyTimeStamp<Tuple2<String, String>>(timestamps2));
List<String> result = MockCoContext.createAndExecute(invokable, inputs1, inputs2);
......
......@@ -23,8 +23,8 @@ import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.streaming.util.MockContext;
import org.apache.flink.streaming.util.keys.ObjectKeySelector;
import org.junit.Test;
public class GroupedReduceInvokableTest {
......@@ -43,7 +43,15 @@ public class GroupedReduceInvokableTest {
@Test
public void test() {
GroupedReduceInvokable<Integer> invokable1 = new GroupedReduceInvokable<Integer>(
new MyReducer(), new ObjectKeySelector<Integer>());
new MyReducer(), new KeySelector<Integer, Integer>() {
private static final long serialVersionUID = 1L;
@Override
public Integer getKey(Integer value) throws Exception {
return value;
}
});
List<Integer> expected = Arrays.asList(1, 2, 2, 4, 3);
List<Integer> actual = MockContext.createAndExecute(invokable1,
......@@ -51,5 +59,4 @@ public class GroupedReduceInvokableTest {
assertEquals(expected, actual);
}
}
......@@ -40,11 +40,20 @@ import org.apache.flink.streaming.api.windowing.policy.TimeTriggerPolicy;
import org.apache.flink.streaming.api.windowing.policy.TriggerPolicy;
import org.apache.flink.streaming.api.windowing.policy.TumblingEvictionPolicy;
import org.apache.flink.streaming.util.MockContext;
import org.apache.flink.streaming.util.keys.TupleKeySelector;
import org.junit.Test;
public class GroupedWindowInvokableTest {
KeySelector<Tuple2<Integer, String>, ?> keySelector = new KeySelector<Tuple2<Integer, String>, String>() {
private static final long serialVersionUID = 1L;
@Override
public String getKey(Tuple2<Integer, String> value) throws Exception {
return value.f1;
}
};
/**
* Tests that illegal arguments result in failure. The following cases are
* tested: 1) having no trigger 2) having no eviction 3) having neither
......@@ -162,7 +171,7 @@ public class GroupedWindowInvokableTest {
expectedDistributedEviction.add(3);
expectedDistributedEviction.add(3);
expectedDistributedEviction.add(15);
List<Integer> expectedCentralEviction = new ArrayList<Integer>();
expectedCentralEviction.add(2);
expectedCentralEviction.add(5);
......@@ -173,7 +182,7 @@ public class GroupedWindowInvokableTest {
expectedCentralEviction.add(5);
expectedCentralEviction.add(1);
expectedCentralEviction.add(5);
LinkedList<CloneableTriggerPolicy<Integer>> triggers = new LinkedList<CloneableTriggerPolicy<Integer>>();
// Trigger on every 2nd element, but the first time after the 3rd
triggers.add(new CountTriggerPolicy<Integer>(2, -1));
......@@ -185,7 +194,7 @@ public class GroupedWindowInvokableTest {
LinkedList<TriggerPolicy<Integer>> centralTriggers = new LinkedList<TriggerPolicy<Integer>>();
ReduceFunction<Integer> reduceFunction=new ReduceFunction<Integer>() {
ReduceFunction<Integer> reduceFunction = new ReduceFunction<Integer>() {
private static final long serialVersionUID = 1L;
@Override
......@@ -193,8 +202,8 @@ public class GroupedWindowInvokableTest {
return value1 + value2;
}
};
KeySelector<Integer, Integer> keySelector=new KeySelector<Integer, Integer>() {
KeySelector<Integer, Integer> keySelector = new KeySelector<Integer, Integer>() {
private static final long serialVersionUID = 1L;
@Override
......@@ -202,7 +211,7 @@ public class GroupedWindowInvokableTest {
return value;
}
};
GroupedWindowInvokable<Integer, Integer> invokable = new GroupedWindowInvokable<Integer, Integer>(
reduceFunction, keySelector, triggers, evictions, centralTriggers, null);
......@@ -213,18 +222,19 @@ public class GroupedWindowInvokableTest {
actual.add(current);
}
assertEquals(new HashSet<Integer>(expectedDistributedEviction), new HashSet<Integer>(actual));
assertEquals(new HashSet<Integer>(expectedDistributedEviction),
new HashSet<Integer>(actual));
assertEquals(expectedDistributedEviction.size(), actual.size());
//Run test with central eviction
// Run test with central eviction
triggers.clear();
centralTriggers.add(new CountTriggerPolicy<Integer>(2, -1));
LinkedList<EvictionPolicy<Integer>> centralEvictions = new LinkedList<EvictionPolicy<Integer>>();
centralEvictions.add(new CountEvictionPolicy<Integer>(2, 2, -1));
invokable = new GroupedWindowInvokable<Integer, Integer>(
reduceFunction, keySelector, triggers, null, centralTriggers,centralEvictions);
invokable = new GroupedWindowInvokable<Integer, Integer>(reduceFunction, keySelector,
triggers, null, centralTriggers, centralEvictions);
result = MockContext.createAndExecute(invokable, inputs);
actual = new LinkedList<Integer>();
for (Integer current : result) {
......@@ -279,8 +289,7 @@ public class GroupedWindowInvokableTest {
return value2;
}
}
}, new TupleKeySelector<Tuple2<Integer, String>>(1), triggers, evictions,
centralTriggers, null);
}, keySelector, triggers, evictions, centralTriggers, null);
List<Tuple2<Integer, String>> result = MockContext.createAndExecute(invokable2, inputs2);
......@@ -387,8 +396,7 @@ public class GroupedWindowInvokableTest {
LinkedList<CloneableTriggerPolicy<Tuple2<Integer, String>>> distributedTriggers = new LinkedList<CloneableTriggerPolicy<Tuple2<Integer, String>>>();
GroupedWindowInvokable<Tuple2<Integer, String>, Tuple2<Integer, String>> invokable = new GroupedWindowInvokable<Tuple2<Integer, String>, Tuple2<Integer, String>>(
myReduceFunction, new TupleKeySelector<Tuple2<Integer, String>>(1),
distributedTriggers, evictions, triggers, null);
myReduceFunction, keySelector, distributedTriggers, evictions, triggers, null);
ArrayList<Tuple2<Integer, String>> result = new ArrayList<Tuple2<Integer, String>>();
for (Tuple2<Integer, String> t : MockContext.createAndExecute(invokable, inputs)) {
......@@ -398,7 +406,7 @@ public class GroupedWindowInvokableTest {
assertEquals(new HashSet<Tuple2<Integer, String>>(expected),
new HashSet<Tuple2<Integer, String>>(result));
assertEquals(expected.size(), result.size());
// repeat the test with central eviction. The result should be the same.
triggers.clear();
triggers.add(new TimeTriggerPolicy<Tuple2<Integer, String>>(2L, myTimeStamp, 2L));
......@@ -407,8 +415,8 @@ public class GroupedWindowInvokableTest {
centralEvictions.add(new TimeEvictionPolicy<Tuple2<Integer, String>>(4L, myTimeStamp));
invokable = new GroupedWindowInvokable<Tuple2<Integer, String>, Tuple2<Integer, String>>(
myReduceFunction, new TupleKeySelector<Tuple2<Integer, String>>(1),
distributedTriggers, evictions, triggers, centralEvictions);
myReduceFunction, keySelector, distributedTriggers, evictions, triggers,
centralEvictions);
result = new ArrayList<Tuple2<Integer, String>>();
for (Tuple2<Integer, String> t : MockContext.createAndExecute(invokable, inputs)) {
......
......@@ -20,11 +20,11 @@ package org.apache.flink.streaming.partitioner;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.plugable.SerializationDelegate;
import org.apache.flink.streaming.api.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.keys.TupleKeySelector;
import org.junit.Before;
import org.junit.Test;
......@@ -42,7 +42,15 @@ public class FieldsPartitionerTest {
@Before
public void setPartitioner() {
fieldsPartitioner = new FieldsPartitioner<Tuple>(new TupleKeySelector<Tuple>(0));
fieldsPartitioner = new FieldsPartitioner<Tuple>(new KeySelector<Tuple, String>() {
private static final long serialVersionUID = 1L;
@Override
public String getKey(Tuple value) throws Exception {
return value.getField(0);
}
});
}
@Test
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.util;
import static org.junit.Assert.assertEquals;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.streaming.util.keys.FieldsKeySelector;
import org.apache.flink.streaming.util.keys.ObjectKeySelector;
import org.apache.flink.streaming.util.keys.TupleKeySelector;
import org.junit.Test;
public class FieldsKeySelectorTest {
@Test
public void testGetKey() throws Exception {
Integer i = 5;
Tuple2<Integer, String> t = new Tuple2<Integer, String>(-1, "a");
double[] a = new double[] { 0.0, 1.2 };
KeySelector<Integer, ?> ks1 = new ObjectKeySelector<Integer>();
assertEquals(ks1.getKey(i), 5);
KeySelector<Tuple2<Integer, String>, ?> ks3 = new TupleKeySelector<Tuple2<Integer, String>>(
1);
assertEquals(ks3.getKey(t), "a");
KeySelector<Tuple2<Integer, String>, ?> ks4 = FieldsKeySelector.getSelector(
TypeExtractor.getForObject(t), 1, 1);
assertEquals(ks4.getKey(t), new Tuple2<String, String>("a", "a"));
KeySelector<double[], ?> ks5 = FieldsKeySelector.getSelector(
TypeExtractor.getForObject(a), 0);
assertEquals(ks5.getKey(a), 0.0);
KeySelector<double[], ?> ks6 = FieldsKeySelector.getSelector(
TypeExtractor.getForObject(a), 1, 0);
assertEquals(ks6.getKey(a), new Tuple2<Double, Double>(1.2, 0.0));
}
}
......@@ -20,7 +20,6 @@ package org.apache.flink.examples.scala.streaming.windowing
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.streaming.StreamExecutionEnvironment
import org.apache.flink.streaming.api.function.source.SourceFunction
import org.apache.flink.util.Collector
import scala.util.Random
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.streaming
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import java.util.ArrayList
import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor
import org.apache.flink.api.java.functions.KeySelector
class CaseClassKeySelector[T <: Product](@transient val typeInfo: CaseClassTypeInfo[T],
val keyFields: String*) extends KeySelector[T, Seq[Any]] {
val numOfKeys: Int = keyFields.length;
@transient val fieldDescriptors = new ArrayList[FlatFieldDescriptor]();
for (field <- keyFields) {
typeInfo.getKey(field, 0, fieldDescriptors);
}
val logicalKeyPositions = new Array[Int](numOfKeys)
val orders = new Array[Boolean](numOfKeys)
for (i <- 0 to numOfKeys - 1) {
logicalKeyPositions(i) = fieldDescriptors.get(i).getPosition();
}
def getKey(value: T): Seq[Any] = {
for (i <- 0 to numOfKeys - 1) yield value.productElement(logicalKeyPositions(i))
}
}
......@@ -116,7 +116,7 @@ class DataStream[T](javaStream: JavaStream[T]) {
*
*/
def groupBy(fields: Int*): DataStream[T] =
new DataStream[T](javaStream.groupBy(new FieldsKeySelector[T](fields: _*)))
new DataStream[T](javaStream.groupBy(fields: _*))
/**
* Groups the elements of a DataStream by the given field expressions to
......@@ -124,12 +124,7 @@ class DataStream[T](javaStream: JavaStream[T]) {
*
*/
def groupBy(firstField: String, otherFields: String*): DataStream[T] =
javaStream.getType() match {
case ccInfo: CaseClassTypeInfo[T] => new DataStream[T](javaStream.groupBy(
new CaseClassKeySelector[T](ccInfo, firstField +: otherFields.toArray: _*)))
case _ => new DataStream[T](javaStream.groupBy(
firstField +: otherFields.toArray: _*))
}
new DataStream[T](javaStream.groupBy(firstField +: otherFields.toArray: _*))
/**
* Groups the elements of a DataStream by the given K key to
......@@ -152,7 +147,7 @@ class DataStream[T](javaStream: JavaStream[T]) {
*
*/
def partitionBy(fields: Int*): DataStream[T] =
new DataStream[T](javaStream.partitionBy(new FieldsKeySelector[T](fields: _*)))
new DataStream[T](javaStream.partitionBy(fields: _*));
/**
* Sets the partitioning of the DataStream so that the output is
......@@ -161,12 +156,7 @@ class DataStream[T](javaStream: JavaStream[T]) {
*
*/
def partitionBy(firstField: String, otherFields: String*): DataStream[T] =
javaStream.getType() match {
case ccInfo: CaseClassTypeInfo[T] => new DataStream[T](javaStream.partitionBy(
new CaseClassKeySelector[T](ccInfo, firstField +: otherFields.toArray: _*)))
case _ => new DataStream[T](javaStream.partitionBy(
firstField +: otherFields.toArray: _*))
}
new DataStream[T](javaStream.partitionBy(firstField +: otherFields.toArray: _*))
/**
* Sets the partitioning of the DataStream so that the output is
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.streaming
import org.apache.flink.streaming.util.keys.{ FieldsKeySelector => JavaSelector }
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.tuple.Tuple
class FieldsKeySelector[IN](fields: Int*) extends KeySelector[IN, Seq[Any]] {
override def getKey(value: IN): Seq[Any] =
value match {
case prod: Product =>
for (i <- 0 to fields.length - 1) yield prod.productElement(fields(i))
case tuple: Tuple =>
for (i <- 0 to fields.length - 1) yield tuple.getField(fields(i))
case _ => throw new RuntimeException("Only tuple types are supported")
}
}
......@@ -23,17 +23,16 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.scala.ClosureCleaner
import org.apache.flink.api.scala.typeutils.CaseClassSerializer
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.streaming.api.datastream.{ DataStream => JavaStream }
import org.apache.flink.streaming.api.datastream.TemporalOperator
import org.apache.flink.streaming.api.function.co.JoinWindowFunction
import org.apache.flink.streaming.util.keys.PojoKeySelector
import scala.reflect.ClassTag
import org.apache.commons.lang.Validate
import org.apache.flink.streaming.api.invokable.operator.co.CoWindowInvokable
import org.apache.flink.streaming.api.function.co.CrossWindowFunction
import org.apache.flink.api.common.functions.CrossFunction
import org.apache.flink.api.scala.typeutils.CaseClassSerializer
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
class StreamCrossOperator[I1, I2](i1: JavaStream[I1], i2: JavaStream[I2]) extends
TemporalOperator[I1, I2, StreamCrossOperator.CrossWindow[I1, I2]](i1, i2) {
......
......@@ -28,10 +28,11 @@ import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.streaming.api.datastream.{ DataStream => JavaStream }
import org.apache.flink.streaming.api.datastream.TemporalOperator
import org.apache.flink.streaming.api.function.co.JoinWindowFunction
import org.apache.flink.streaming.util.keys.PojoKeySelector
import scala.reflect.ClassTag
import org.apache.commons.lang.Validate
import org.apache.flink.streaming.api.invokable.operator.co.CoWindowInvokable
import org.apache.flink.streaming.util.keys.KeySelectorUtil
import org.apache.flink.api.java.operators.Keys
class StreamJoinOperator[I1, I2](i1: JavaStream[I1], i2: JavaStream[I2]) extends
TemporalOperator[I1, I2, StreamJoinOperator.JoinWindow[I1, I2]](i1, i2) {
......@@ -43,8 +44,10 @@ TemporalOperator[I1, I2, StreamJoinOperator.JoinWindow[I1, I2]](i1, i2) {
object StreamJoinOperator {
class JoinWindow[I1, I2](op: StreamJoinOperator[I1, I2]) {
class JoinWindow[I1, I2](private[flink] op: StreamJoinOperator[I1, I2]) {
private[flink] val type1 = op.input1.getType();
/**
* Continues a temporal Join transformation by defining
* the fields in the first stream to be used as keys for the join.
......@@ -52,7 +55,8 @@ object StreamJoinOperator {
* to define the second key.
*/
def where(fields: Int*) = {
new JoinPredicate[I1, I2](op, new FieldsKeySelector[I1](fields: _*))
new JoinPredicate[I1, I2](op, KeySelectorUtil.getSelectorForKeys(
new Keys.ExpressionKeys(fields.toArray,type1),type1))
}
/**
......@@ -62,12 +66,8 @@ object StreamJoinOperator {
* to define the second key.
*/
def where(firstField: String, otherFields: String*) =
op.input1.getType() match {
case ccInfo: CaseClassTypeInfo[I1] => new JoinPredicate[I1, I2](op,
new CaseClassKeySelector[I1](ccInfo, firstField +: otherFields.toArray: _*))
case _ => new JoinPredicate[I1, I2](op, new PojoKeySelector[I1](
op.input1.getType(), (firstField +: otherFields): _*))
}
new JoinPredicate[I1, I2](op, KeySelectorUtil.getSelectorForKeys(
new Keys.ExpressionKeys(firstField +: otherFields.toArray,type1),type1))
/**
* Continues a temporal Join transformation by defining
......@@ -90,6 +90,7 @@ object StreamJoinOperator {
class JoinPredicate[I1, I2](private[flink] val op: StreamJoinOperator[I1, I2],
private[flink] val keys1: KeySelector[I1, _]) {
private[flink] var keys2: KeySelector[I2, _] = null
private[flink] val type2 = op.input2.getType();
/**
* Creates a temporal join transformation by defining the second join key.
......@@ -98,7 +99,8 @@ object StreamJoinOperator {
* To define a custom wrapping, use JoinedStream.apply(...)
*/
def equalTo(fields: Int*): JoinedStream[I1, I2] = {
finish(new FieldsKeySelector[I2](fields: _*))
finish(KeySelectorUtil.getSelectorForKeys(
new Keys.ExpressionKeys(fields.toArray,type2),type2))
}
/**
......@@ -108,12 +110,8 @@ object StreamJoinOperator {
* To define a custom wrapping, use JoinedStream.apply(...)
*/
def equalTo(firstField: String, otherFields: String*): JoinedStream[I1, I2] =
op.input2.getType() match {
case ccInfo: CaseClassTypeInfo[I2] => finish(
new CaseClassKeySelector[I2](ccInfo, firstField +: otherFields.toArray: _*))
case _ => finish(new PojoKeySelector[I2](op.input2.getType(),
(firstField +: otherFields): _*))
}
finish(KeySelectorUtil.getSelectorForKeys(
new Keys.ExpressionKeys(firstField +: otherFields.toArray,type2),type2))
/**
* Creates a temporal join transformation by defining the second join key.
......
......@@ -66,7 +66,7 @@ class WindowedDataStream[T](javaStream: JavaWStream[T]) {
*
*/
def groupBy(fields: Int*): WindowedDataStream[T] =
new WindowedDataStream[T](javaStream.groupBy(new FieldsKeySelector[T](fields: _*)))
new WindowedDataStream[T](javaStream.groupBy(fields: _*))
/**
* Groups the elements of the WindowedDataStream using the given
......@@ -78,13 +78,8 @@ class WindowedDataStream[T](javaStream: JavaWStream[T]) {
*
*/
def groupBy(firstField: String, otherFields: String*): WindowedDataStream[T] =
javaStream.getType() match {
case ccInfo: CaseClassTypeInfo[T] => new WindowedDataStream[T](javaStream.groupBy(
new CaseClassKeySelector[T](ccInfo, firstField +: otherFields.toArray: _*)))
case _ => new WindowedDataStream[T](javaStream.groupBy(
firstField +: otherFields.toArray: _*))
}
new WindowedDataStream[T](javaStream.groupBy(
firstField +: otherFields.toArray: _*))
/**
* Groups the elements of the WindowedDataStream using the given
......