diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index 078679ddb4b9e05ac1e570511811b8c76818c2b5..9074b7aa1da5523d5a9f424023d856721c0031ed 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -343,6 +343,15 @@ public abstract class AbstractStreamOperator } } } + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void setKeyContext(Object key) { + if (keyValueStates != null) { + for (KvState kv : keyValueStates) { + kv.setCurrentKey(key); + } + } + } // ------------------------------------------------------------------------ // Context and chaining properties diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java index 3165f888e3f355505f15295d101e047c42defbc7..90d3d82bf6b32d73b7280491f1ab3cbb9d167b9d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java @@ -239,7 +239,7 @@ public abstract class AbstractAlignedProcessingTimeWindowOperator { public abstract void addElementToLatestPane(Type element) throws Exception; - public abstract void evaluateWindow(Collector out, TimeWindow window) throws Exception; + public abstract void evaluateWindow(Collector out, TimeWindow window, AbstractStreamOperator operator) throws Exception; public void dispose() { 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingKeyedTimePanes.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingKeyedTimePanes.java index c854e6c9c6af4e5f06ab3b1095bca3ee2ac5b70a..e15de8e047b5da5c77dfe5ef443d9841f108bb12 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingKeyedTimePanes.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingKeyedTimePanes.java @@ -21,6 +21,7 @@ package org.apache.flink.streaming.runtime.operators.windowing; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.runtime.util.UnionIterator; import org.apache.flink.streaming.api.functions.windowing.WindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.windowing.windows.TimeWindow; import org.apache.flink.streaming.api.windowing.windows.Window; import org.apache.flink.util.Collector; @@ -57,16 +58,21 @@ public class AccumulatingKeyedTimePanes extends AbstractKeyed } @Override - public void evaluateWindow(Collector out, TimeWindow window) throws Exception { + public void evaluateWindow(Collector out, TimeWindow window, + AbstractStreamOperator operator) throws Exception + { if (previousPanes.isEmpty()) { // optimized path for single pane case (tumbling window) for (KeyMap.Entry> entry : latestPane) { + Key key = entry.getKey(); + operator.setKeyContext(key); function.apply(entry.getKey(), window, entry.getValue(), out); } } else { // general code path for multi-pane case - WindowFunctionTraversal evaluator = new WindowFunctionTraversal<>(function, window, out); + WindowFunctionTraversal evaluator = new WindowFunctionTraversal<>( + function, window, out, operator); traverseAllPanes(evaluator, evaluationPass); } @@ -84,16 +90,21 @@ public class 
AccumulatingKeyedTimePanes extends AbstractKeyed private final UnionIterator unionIterator; private final Collector out; + + private final TimeWindow window; + + private final AbstractStreamOperator contextOperator; private Key currentKey; + - private TimeWindow window; - - WindowFunctionTraversal(WindowFunction function, TimeWindow window, Collector out) { + WindowFunctionTraversal(WindowFunction function, TimeWindow window, + Collector out, AbstractStreamOperator contextOperator) { this.function = function; this.out = out; this.unionIterator = new UnionIterator<>(); this.window = window; + this.contextOperator = contextOperator; } @@ -110,6 +121,7 @@ public class AccumulatingKeyedTimePanes extends AbstractKeyed @Override public void keyDone() throws Exception { + contextOperator.setKeyContext(currentKey); function.apply(currentKey, window, unionIterator, out); } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingKeyedTimePanes.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingKeyedTimePanes.java index d395b2af3bcf0cb9238be7358196a7a2e9c78667..8599bc17ed6aaec555cfc4149897dde3fbe6d78c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingKeyedTimePanes.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingKeyedTimePanes.java @@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.operators.windowing; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.windowing.windows.TimeWindow; import org.apache.flink.util.Collector; @@ -50,7 +51,8 @@ public class AggregatingKeyedTimePanes extends AbstractKeyedTimePanes } @Override - public void 
evaluateWindow(Collector out, TimeWindow window) throws Exception { + public void evaluateWindow(Collector out, TimeWindow window, + AbstractStreamOperator operator) throws Exception { if (previousPanes.isEmpty()) { // optimized path for single pane case for (KeyMap.Entry entry : latestPane) { @@ -59,7 +61,7 @@ public class AggregatingKeyedTimePanes extends AbstractKeyedTimePanes } else { // general code path for multi-pane case - AggregatingTraversal evaluator = new AggregatingTraversal<>(reducer, out); + AggregatingTraversal evaluator = new AggregatingTraversal<>(reducer, out, operator); traverseAllPanes(evaluator, evaluationPass); } @@ -76,16 +78,21 @@ public class AggregatingKeyedTimePanes extends AbstractKeyedTimePanes private final Collector out; + private final AbstractStreamOperator operator; + private Type currentValue; - AggregatingTraversal(ReduceFunction function, Collector out) { + AggregatingTraversal(ReduceFunction function, Collector out, + AbstractStreamOperator operator) { this.function = function; this.out = out; + this.operator = operator; } @Override public void startNewKey(Key key) { currentValue = null; + operator.setKeyContext(key); } @Override diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java index ad3c838bb5f8013e609bef0aa10e2a39345f24be..62eb268dac38baf2b7ac16121c522d767f126e16 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java @@ -20,11 +20,15 @@ package org.apache.flink.streaming.runtime.operators.windowing; import 
org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.accumulators.Accumulator; +import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.ClosureCleaner; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction; import org.apache.flink.streaming.api.functions.windowing.WindowFunction; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.Output; @@ -48,7 +52,9 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -89,6 +95,13 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { // ------------------------------------------------------------------------ + public AccumulatingAlignedProcessingTimeWindowOperatorTest() { + ClosureCleaner.clean(identitySelector, false); + ClosureCleaner.clean(validatingIdentityFunction, false); + } + + // ------------------------------------------------------------------------ + @After public void checkNoTriggerThreadsRunning() { // make sure that all the threads we trigger are shut down @@ -544,6 +557,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { // inject some elements final int numElementsFirst = 700; + final int numElements = 1000; for (int i = 0; i < numElementsFirst; i++) { synchronized (lock) { op.processElement(new 
StreamRecord(i)); @@ -560,6 +574,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { resultAtSnapshot = new ArrayList<>(out.getElements()); int afterSnapShot = out.getElements().size(); assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); + assertTrue(afterSnapShot <= numElementsFirst); } // inject some random elements, which should not show up in the state @@ -584,7 +599,6 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { op.open(); // inject some more elements - final int numElements = 1000; for (int i = numElementsFirst; i < numElements; i++) { synchronized (lock) { op.processElement(new StreamRecord(i)); @@ -725,6 +739,64 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { } } + @Test + public void testKeyValueStateInWindowFunction() { + final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); + try { + final CollectingOutput out = new CollectingOutput<>(50); + final Object lock = new Object(); + final StreamTask mockTask = createMockTaskWithTimer(timerService, lock); + + StatefulFunction.globalCounts.clear(); + + // tumbling window (size == slide == 50) that triggers every 50 milliseconds + AccumulatingProcessingTimeWindowOperator op = + new AccumulatingProcessingTimeWindowOperator<>( + new StatefulFunction(), identitySelector, + IntSerializer.INSTANCE, IntSerializer.INSTANCE, 50, 50); + + op.setup(mockTask, createTaskConfig(identitySelector, IntSerializer.INSTANCE), out); + op.open(); + + synchronized (lock) { + op.processElement(new StreamRecord(1)); + op.processElement(new StreamRecord(2)); + } + out.waitForNElements(2, 60000); + + synchronized (lock) { + op.processElement(new StreamRecord(1)); + op.processElement(new StreamRecord(2)); + op.processElement(new StreamRecord(1)); + op.processElement(new StreamRecord(1)); + op.processElement(new StreamRecord(2)); + op.processElement(new StreamRecord(2)); + } + out.waitForNElements(8, 60000); + + List 
result = out.getElements(); + assertEquals(8, result.size()); + + Collections.sort(result); + assertEquals(Arrays.asList(1, 1, 1, 1, 2, 2, 2, 2), result); + + assertEquals(4, StatefulFunction.globalCounts.get(1).intValue()); + assertEquals(4, StatefulFunction.globalCounts.get(2).intValue()); + + synchronized (lock) { + op.close(); + } + op.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + finally { + timerService.shutdown(); + } + } + // ------------------------------------------------------------------------ private void assertInvalidParameter(long windowSize, long windowSlide) { @@ -771,6 +843,41 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { } } + // ------------------------------------------------------------------------ + + private static class StatefulFunction extends RichWindowFunction { + + // we use a concurrent map here even though there is no concurrency, to + // get "volatile" style access to entries + static final Map globalCounts = new ConcurrentHashMap<>(); + + private OperatorState state; + + @Override + public void open(Configuration parameters) { + assertNotNull(getRuntimeContext()); + state = getRuntimeContext().getKeyValueState("totalCount", Integer.class, 0); + } + + @Override + public void apply(Integer key, + TimeWindow window, + Iterable values, + Collector out) throws Exception { + for (Integer i : values) { + // we need to update this state before emitting elements. 
Else, the test's main + // thread will have received all output elements before the state is updated and + // the checks may fail + state.update(state.value() + 1); + globalCounts.put(key, state.value()); + + out.collect(i); + } + } + } + + // ------------------------------------------------------------------------ + private static StreamTask createMockTask() { StreamTask task = mock(StreamTask.class); when(task.getAccumulatorMap()).thenReturn(new HashMap>()); @@ -821,4 +928,11 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { return mockTask; } + + private static StreamConfig createTaskConfig(KeySelector partitioner, TypeSerializer keySerializer) { + StreamConfig cfg = new StreamConfig(new Configuration()); + cfg.setStatePartitioner(partitioner); + cfg.setStateKeySerializer(keySerializer); + return cfg; + } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java index 4bd260fe877d5c63b6ff9bc7f57d66956152771f..4d507fb0919dc2c48f8b9d12aeffb1adbebe0d14 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java @@ -21,9 +21,16 @@ package org.apache.flink.streaming.runtime.operators.windowing; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.accumulators.Accumulator; import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichReduceFunction; +import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import 
org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.ClosureCleaner; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -33,8 +40,8 @@ import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.runtime.operators.Triggerable; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; - import org.apache.flink.streaming.runtime.tasks.StreamTaskState; + import org.junit.After; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; @@ -44,15 +51,19 @@ import org.mockito.stubbing.OngoingStubbing; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; @@ -70,20 +81,41 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { @SuppressWarnings("unchecked") private final KeySelector mockKeySelector = mock(KeySelector.class); - private final KeySelector identitySelector = new KeySelector() { + private final KeySelector, 
Integer> fieldOneSelector = + new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) { + return value.f0; + } + }; + + private final ReduceFunction> sumFunction = new ReduceFunction>() { @Override - public Integer getKey(Integer value) { - return value; + public Tuple2 reduce(Tuple2 value1, Tuple2 value2) { + return new Tuple2<>(value1.f0, value1.f1 + value2.f1); } }; - private final ReduceFunction sumFunction = new ReduceFunction() { + private final TypeSerializer> tupleSerializer = + new TupleTypeInfo>(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO) + .createSerializer(new ExecutionConfig()); + + private final Comparator> tupleComparator = new Comparator>() { @Override - public Integer reduce(Integer value1, Integer value2) { - return value1 + value2; + public int compare(Tuple2 o1, Tuple2 o2) { + int cmp0 = Integer.compare(o1.f0, o2.f0); /* overflow-safe, unlike int subtraction */ + int cmp1 = Integer.compare(o1.f1, o2.f1); + return cmp0 != 0 ? cmp0 : cmp1; } }; + + // ------------------------------------------------------------------------ + public AggregatingAlignedProcessingTimeWindowOperatorTest() { + ClosureCleaner.clean(fieldOneSelector, false); + ClosureCleaner.clean(sumFunction, false); + } + // ------------------------------------------------------------------------ @After @@ -211,12 +243,12 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); try { final int windowSize = 50; - final CollectingOutput out = new CollectingOutput<>(windowSize); + final CollectingOutput> out = new CollectingOutput<>(windowSize); - AggregatingProcessingTimeWindowOperator op = + AggregatingProcessingTimeWindowOperator> op = new AggregatingProcessingTimeWindowOperator<>( - sumFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, + sumFunction, fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, windowSize, windowSize); final Object lock = new Object(); @@ -229,7 +261,9 @@ 
public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = 0; i < numElements; i++) { synchronized (lock) { - op.processElement(new StreamRecord(i)); + StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); + op.setKeyContextElement(next); + op.processElement(next); } Thread.sleep(1); } @@ -240,12 +274,13 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { op.dispose(); // get and verify the result - List result = out.getElements(); + List> result = out.getElements(); assertEquals(numElements, result.size()); - Collections.sort(result); + Collections.sort(result, tupleComparator); for (int i = 0; i < numElements; i++) { - assertEquals(i, result.get(i).intValue()); + assertEquals(i, result.get(i).f0.intValue()); + assertEquals(i, result.get(i).f1.intValue()); } } catch (Exception e) { @@ -263,15 +298,15 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { try { final int windowSize = 50; - final CollectingOutput out = new CollectingOutput<>(windowSize); + final CollectingOutput> out = new CollectingOutput<>(windowSize); final Object lock = new Object(); final StreamTask mockTask = createMockTaskWithTimer(timerService, lock); - - AggregatingProcessingTimeWindowOperator op = + + AggregatingProcessingTimeWindowOperator> op = new AggregatingProcessingTimeWindowOperator<>( - sumFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, + sumFunction, fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, windowSize, windowSize); op.setup(mockTask, new StreamConfig(new Configuration()), out); @@ -286,8 +321,10 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { synchronized (lock) { long nextTime = op.getNextEvaluationTime(); int val = ((int) nextTime) ^ ((int) (nextTime >>> 32)); - - op.processElement(new StreamRecord(val)); + + StreamRecord> next = new StreamRecord<>(new Tuple2<>(val, val)); + op.setKeyContextElement(next); + op.processElement(next); if (nextTime != 
previousNextTime) { window++; @@ -302,14 +339,14 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { } op.dispose(); - List result = out.getElements(); + List> result = out.getElements(); // we have ideally one element per window. we may have more, when we emitted a value into the // successive window (corner case), so we can have twice the number of elements, in the worst case. assertTrue(result.size() >= numWindows && result.size() <= 2 * numWindows); // deduplicate for more accurate checks - HashSet set = new HashSet<>(result); + HashSet> set = new HashSet<>(result); assertTrue(set.size() == 10); } catch (Exception e) { @@ -325,16 +362,16 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { public void testSlidingWindow() { final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); try { - final CollectingOutput out = new CollectingOutput<>(50); + final CollectingOutput> out = new CollectingOutput<>(50); final Object lock = new Object(); final StreamTask mockTask = createMockTaskWithTimer(timerService, lock); // tumbling window that triggers every 20 milliseconds - AggregatingProcessingTimeWindowOperator op = + AggregatingProcessingTimeWindowOperator> op = new AggregatingProcessingTimeWindowOperator<>( - sumFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, + sumFunction, fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, 150, 50); op.setup(mockTask, new StreamConfig(new Configuration()), out); @@ -344,7 +381,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = 0; i < numElements; i++) { synchronized (lock) { - op.processElement(new StreamRecord(i)); + StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); + op.setKeyContextElement(next); + op.processElement(next); } Thread.sleep(1); } @@ -355,7 +394,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { op.dispose(); // get and verify the result - List 
result = out.getElements(); + List> result = out.getElements(); // every element can occur between one and three times if (result.size() < numElements || result.size() > 3 * numElements) { @@ -363,17 +402,19 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { fail("Wrong number of results: " + result.size()); } - Collections.sort(result); + Collections.sort(result, tupleComparator); int lastNum = -1; int lastCount = -1; - for (int num : result) { - if (num == lastNum) { + for (Tuple2 val : result) { + assertEquals(val.f0, val.f1); + + if (val.f0 == lastNum) { lastCount++; assertTrue(lastCount <= 3); } else { - lastNum = num; + lastNum = val.f0; lastCount = 1; } } @@ -392,33 +433,45 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); try { - final CollectingOutput out = new CollectingOutput<>(50); + final CollectingOutput> out = new CollectingOutput<>(50); final Object lock = new Object(); final StreamTask mockTask = createMockTaskWithTimer(timerService, lock); // tumbling window that triggers every 20 milliseconds - AggregatingProcessingTimeWindowOperator op = + AggregatingProcessingTimeWindowOperator> op = new AggregatingProcessingTimeWindowOperator<>( - sumFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, 150, 50); + sumFunction, fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, 150, 50); op.setup(mockTask, new StreamConfig(new Configuration()), out); op.open(); synchronized (lock) { - op.processElement(new StreamRecord(1)); - op.processElement(new StreamRecord(2)); + StreamRecord> next1 = new StreamRecord<>(new Tuple2<>(1, 1)); + op.setKeyContextElement(next1); + op.processElement(next1); + + StreamRecord> next2 = new StreamRecord<>(new Tuple2<>(2, 2)); + op.setKeyContextElement(next2); + op.processElement(next2); } // each element should end up in the output three times // wait until the elements 
have arrived 6 times in the output out.waitForNElements(6, 120000); - List result = out.getElements(); + List> result = out.getElements(); assertEquals(6, result.size()); - Collections.sort(result); - assertEquals(Arrays.asList(1, 1, 1, 2, 2, 2), result); + Collections.sort(result, tupleComparator); + assertEquals(Arrays.asList( + new Tuple2<>(1, 1), + new Tuple2<>(1, 1), + new Tuple2<>(1, 1), + new Tuple2<>(2, 2), + new Tuple2<>(2, 2), + new Tuple2<>(2, 2) + ), result); synchronized (lock) { op.close(); @@ -438,15 +491,16 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { public void testEmitTrailingDataOnClose() { final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); try { - final CollectingOutput out = new CollectingOutput<>(); + final CollectingOutput> out = new CollectingOutput<>(); final Object lock = new Object(); final StreamTask mockTask = createMockTaskWithTimer(timerService, lock); // the operator has a window time that is so long that it will not fire in this test final long oneYear = 365L * 24 * 60 * 60 * 1000; - AggregatingProcessingTimeWindowOperator op = - new AggregatingProcessingTimeWindowOperator<>(sumFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, oneYear, oneYear); + AggregatingProcessingTimeWindowOperator> op = + new AggregatingProcessingTimeWindowOperator<>( + sumFunction, fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, oneYear, oneYear); op.setup(mockTask, new StreamConfig(new Configuration()), out); op.open(); @@ -454,7 +508,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { List data = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); for (Integer i : data) { synchronized (lock) { - op.processElement(new StreamRecord(i)); + StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); + op.setKeyContextElement(next); + op.processElement(next); } } @@ -464,9 +520,14 @@ public class 
AggregatingAlignedProcessingTimeWindowOperatorTest { op.dispose(); // get and verify the result - List result = out.getElements(); - Collections.sort(result); - assertEquals(data, result); + List> result = out.getElements(); + assertEquals(data.size(), result.size()); + + Collections.sort(result, tupleComparator); + for (int i = 0; i < data.size(); i++) { + assertEquals(data.get(i), result.get(i).f0); + assertEquals(data.get(i), result.get(i).f1); + } } catch (Exception e) { e.printStackTrace(); @@ -481,18 +542,18 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { public void testPropagateExceptionsFromProcessElement() { final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); try { - final CollectingOutput out = new CollectingOutput<>(); + final CollectingOutput> out = new CollectingOutput<>(); final Object lock = new Object(); final StreamTask mockTask = createMockTaskWithTimer(timerService, lock); - ReduceFunction failingFunction = new FailingFunction(100); + ReduceFunction> failingFunction = new FailingFunction(100); // the operator has a window time that is so long that it will not fire in this test final long hundredYears = 100L * 365 * 24 * 60 * 60 * 1000; - AggregatingProcessingTimeWindowOperator op = + AggregatingProcessingTimeWindowOperator> op = new AggregatingProcessingTimeWindowOperator<>( - failingFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, + failingFunction, fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, hundredYears, hundredYears); op.setup(mockTask, new StreamConfig(new Configuration()), out); @@ -500,12 +561,16 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = 0; i < 100; i++) { synchronized (lock) { - op.processElement(new StreamRecord(1)); + StreamRecord> next = new StreamRecord<>(new Tuple2<>(1, 1)); + op.setKeyContextElement(next); + op.processElement(next); } } try { - op.processElement(new StreamRecord(1)); + 
StreamRecord> next = new StreamRecord<>(new Tuple2<>(1, 1)); + op.setKeyContextElement(next); + op.processElement(next); fail("This fail with an exception"); } catch (Exception e) { @@ -528,15 +593,15 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); try { final int windowSize = 200; - final CollectingOutput out = new CollectingOutput<>(windowSize); + final CollectingOutput> out = new CollectingOutput<>(windowSize); final Object lock = new Object(); final StreamTask mockTask = createMockTaskWithTimer(timerService, lock); // tumbling window that triggers every 50 milliseconds - AggregatingProcessingTimeWindowOperator op = + AggregatingProcessingTimeWindowOperator> op = new AggregatingProcessingTimeWindowOperator<>( - sumFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, + sumFunction, fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, windowSize, windowSize); op.setup(mockTask, new StreamConfig(new Configuration()), out); @@ -548,14 +613,16 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = 0; i < numElementsFirst; i++) { synchronized (lock) { - op.processElement(new StreamRecord(i)); + StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); + op.setKeyContextElement(next); + op.processElement(next); } Thread.sleep(1); } // draw a snapshot and dispose the window StreamTaskState state; - List resultAtSnapshot; + List> resultAtSnapshot; synchronized (lock) { int beforeSnapShot = out.getElements().size(); state = op.snapshotOperatorState(1L, System.currentTimeMillis()); @@ -569,7 +636,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { // inject some random elements, which should not show up in the state for (int i = numElementsFirst; i < numElements; i++) { synchronized (lock) { - op.processElement(new StreamRecord(i)); + StreamRecord> next = new StreamRecord<>(new 
Tuple2<>(i, i)); + op.setKeyContextElement(next); + op.processElement(next); } Thread.sleep(1); } @@ -577,10 +646,10 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { op.dispose(); // re-create the operator and restore the state - final CollectingOutput out2 = new CollectingOutput<>(windowSize); + final CollectingOutput> out2 = new CollectingOutput<>(windowSize); op = new AggregatingProcessingTimeWindowOperator<>( - sumFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, + sumFunction, fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, windowSize, windowSize); op.setup(mockTask, new StreamConfig(new Configuration()), out2); @@ -590,7 +659,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { // inject the remaining elements for (int i = numElementsFirst; i < numElements; i++) { synchronized (lock) { - op.processElement(new StreamRecord(i)); + StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); + op.setKeyContextElement(next); + op.processElement(next); } Thread.sleep(1); } @@ -601,13 +672,14 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { op.dispose(); // get and verify the result - List finalResult = new ArrayList<>(resultAtSnapshot); + List> finalResult = new ArrayList<>(resultAtSnapshot); finalResult.addAll(out2.getElements()); assertEquals(numElements, finalResult.size()); - Collections.sort(finalResult); + Collections.sort(finalResult, tupleComparator); for (int i = 0; i < numElements; i++) { - assertEquals(i, finalResult.get(i).intValue()); + assertEquals(i, finalResult.get(i).f0.intValue()); + assertEquals(i, finalResult.get(i).f1.intValue()); } } catch (Exception e) { @@ -627,15 +699,15 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { final int windowSlide = 50; final int windowSize = factor * windowSlide; - final CollectingOutput out = new CollectingOutput<>(windowSlide); + final CollectingOutput> out = new 
CollectingOutput<>(windowSlide); final Object lock = new Object(); final StreamTask mockTask = createMockTaskWithTimer(timerService, lock); // sliding window (200 msecs) every 50 msecs - AggregatingProcessingTimeWindowOperator op = + AggregatingProcessingTimeWindowOperator> op = new AggregatingProcessingTimeWindowOperator<>( - sumFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, + sumFunction, fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, windowSize, windowSlide); op.setup(mockTask, new StreamConfig(new Configuration()), out); @@ -647,14 +719,16 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { for (int i = 0; i < numElementsFirst; i++) { synchronized (lock) { - op.processElement(new StreamRecord(i)); + StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); + op.setKeyContextElement(next); + op.processElement(next); } Thread.sleep(1); } // draw a snapshot StreamTaskState state; - List resultAtSnapshot; + List> resultAtSnapshot; synchronized (lock) { int beforeSnapShot = out.getElements().size(); state = op.snapshotOperatorState(1L, System.currentTimeMillis()); @@ -668,7 +742,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { // inject the remaining elements - these should not influence the snapshot for (int i = numElementsFirst; i < numElements; i++) { synchronized (lock) { - op.processElement(new StreamRecord(i)); + StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); + op.setKeyContextElement(next); + op.processElement(next); } Thread.sleep(1); } @@ -676,10 +752,10 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { op.dispose(); // re-create the operator and restore the state - final CollectingOutput out2 = new CollectingOutput<>(windowSlide); + final CollectingOutput> out2 = new CollectingOutput<>(windowSlide); op = new AggregatingProcessingTimeWindowOperator<>( - sumFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, + 
sumFunction, fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, windowSize, windowSlide); op.setup(mockTask, new StreamConfig(new Configuration()), out2); @@ -690,7 +766,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { // inject again the remaining elements for (int i = numElementsFirst; i < numElements; i++) { synchronized (lock) { - op.processElement(new StreamRecord(i)); + StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); + op.setKeyContextElement(next); + op.processElement(next); } Thread.sleep(1); } @@ -710,13 +788,14 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { op.dispose(); // get and verify the result - List finalResult = new ArrayList<>(resultAtSnapshot); + List> finalResult = new ArrayList<>(resultAtSnapshot); finalResult.addAll(out2.getElements()); assertEquals(factor * numElements, finalResult.size()); - Collections.sort(finalResult); + Collections.sort(finalResult, tupleComparator); for (int i = 0; i < factor * numElements; i++) { - assertEquals(i / factor, finalResult.get(i).intValue()); + assertEquals(i / factor, finalResult.get(i).f0.intValue()); + assertEquals(i / factor, finalResult.get(i).f1.intValue()); } } catch (Exception e) { @@ -727,6 +806,134 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { timerService.shutdown(); } } + + @Test + public void testKeyValueStateInWindowFunctionTumbling() { + final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); + try { + final long hundredYears = 100L * 365 * 24 * 60 * 60 * 1000; + + final CollectingOutput> out = new CollectingOutput<>(); + final Object lock = new Object(); + final StreamTask mockTask = createMockTaskWithTimer(timerService, lock); + + StatefulFunction.globalCounts.clear(); + + AggregatingProcessingTimeWindowOperator> op = + new AggregatingProcessingTimeWindowOperator<>( + new StatefulFunction(), fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, hundredYears, 
hundredYears); + + op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE), out); + op.open(); + + // because the window interval is so large, everything should be in one window + // and aggregate into one value per key + + synchronized (lock) { + for (int i = 0; i < 10; i++) { + StreamRecord> next1 = new StreamRecord<>(new Tuple2<>(1, i)); + op.setKeyContextElement(next1); + op.processElement(next1); + + StreamRecord> next2 = new StreamRecord<>(new Tuple2<>(2, i)); + op.setKeyContextElement(next2); + op.processElement(next2); + } + + op.close(); + } + + List> result = out.getElements(); + assertEquals(2, result.size()); + + Collections.sort(result, tupleComparator); + assertEquals(45, result.get(0).f1.intValue()); + assertEquals(45, result.get(1).f1.intValue()); + + assertEquals(10, StatefulFunction.globalCounts.get(1).intValue()); + assertEquals(10, StatefulFunction.globalCounts.get(2).intValue()); + + op.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + finally { + timerService.shutdown(); + } + } + + @Test + public void testKeyValueStateInWindowFunctionSliding() { + final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); + try { + final int factor = 2; + final int windowSlide = 50; + final int windowSize = factor * windowSlide; + + final CollectingOutput> out = new CollectingOutput<>(); + final Object lock = new Object(); + final StreamTask mockTask = createMockTaskWithTimer(timerService, lock); + + StatefulFunction.globalCounts.clear(); + + AggregatingProcessingTimeWindowOperator> op = + new AggregatingProcessingTimeWindowOperator<>( + new StatefulFunction(), fieldOneSelector, + IntSerializer.INSTANCE, tupleSerializer, windowSize, windowSlide); + + op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE), out); + op.open(); + + // because the window interval is so large, everything should be in one window + // and aggregate into one value 
per key + final int numElements = 100; + + // because we do not release the lock here, these elements + for (int i = 0; i < numElements; i++) { + + StreamRecord> next1 = new StreamRecord<>(new Tuple2<>(1, i)); + StreamRecord> next2 = new StreamRecord<>(new Tuple2<>(2, i)); + StreamRecord> next3 = new StreamRecord<>(new Tuple2<>(1, i)); + StreamRecord> next4 = new StreamRecord<>(new Tuple2<>(2, i)); + + // because we do not release the lock between elements, they end up in the same windows + synchronized (lock) { + op.setKeyContextElement(next1); + op.processElement(next1); + op.setKeyContextElement(next2); + op.processElement(next2); + op.setKeyContextElement(next3); + op.processElement(next3); + op.setKeyContextElement(next4); + op.processElement(next4); + } + + Thread.sleep(1); + } + + synchronized (lock) { + op.close(); + } + + int count1 = StatefulFunction.globalCounts.get(1); + int count2 = StatefulFunction.globalCounts.get(2); + + assertTrue(count1 >= 2 && count1 <= 2 * numElements); + assertEquals(count1, count2); + + op.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + finally { + timerService.shutdown(); + } + } // ------------------------------------------------------------------------ @@ -748,7 +955,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { // ------------------------------------------------------------------------ - private static class FailingFunction implements ReduceFunction { + private static class FailingFunction implements ReduceFunction> { private final int failAfterElements; @@ -759,16 +966,44 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { } @Override - public Integer reduce(Integer value1, Integer value2) throws Exception { + public Tuple2 reduce(Tuple2 value1, Tuple2 value2) throws Exception { numElements++; if (numElements >= failAfterElements) { throw new Exception("Artificial Test Exception"); } - return value1 + value2; + return new 
Tuple2<>(value1.f0, value1.f1 + value2.f1); } } + + // ------------------------------------------------------------------------ + + private static class StatefulFunction extends RichReduceFunction> { + + static final Map globalCounts = new ConcurrentHashMap<>(); + + private OperatorState state; + + @Override + public void open(Configuration parameters) { + assertNotNull(getRuntimeContext()); + + // start with one, so the final count is correct and we test that we do not + // initialize with 0 always by default + state = getRuntimeContext().getKeyValueState("totalCount", Integer.class, 1); + } + + @Override + public Tuple2 reduce(Tuple2 value1, Tuple2 value2) throws Exception { + state.update(state.value() + 1); + globalCounts.put(value1.f0, state.value()); + + return new Tuple2<>(value1.f0, value1.f1 + value2.f1); + } + } + + // ------------------------------------------------------------------------ private static StreamTask createMockTask() { StreamTask task = mock(StreamTask.class); @@ -820,4 +1055,11 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest { return mockTask; } + + private static StreamConfig createTaskConfig(KeySelector partitioner, TypeSerializer keySerializer) { + StreamConfig cfg = new StreamConfig(new Configuration()); + cfg.setStatePartitioner(partitioner); + cfg.setStateKeySerializer(keySerializer); + return cfg; + } }