提交 3dc7423c 编写于 作者: S Stephan Ewen

[FLINK-2891] [streaming] Set keys for key/value state in window evaluation of fast-path windows.

上级 6424ce57
......@@ -343,6 +343,15 @@ public abstract class AbstractStreamOperator<OUT>
}
}
}
@SuppressWarnings({"unchecked", "rawtypes"})
public void setKeyContext(Object key) {
if (keyValueStates != null) {
for (KvState kv : keyValueStates) {
kv.setCurrentKey(key);
}
}
}
// ------------------------------------------------------------------------
// Context and chaining properties
......
......@@ -239,7 +239,7 @@ public abstract class AbstractAlignedProcessingTimeWindowOperator<KEY, IN, OUT,
private void computeWindow(long timestamp) throws Exception {
out.setTimestamp(timestamp);
panes.truncatePanes(numPanesPerWindow);
panes.evaluateWindow(out, new TimeWindow(timestamp, timestamp + windowSize));
panes.evaluateWindow(out, new TimeWindow(timestamp, timestamp + windowSize), this);
}
// ------------------------------------------------------------------------
......
......@@ -21,6 +21,7 @@ package org.apache.flink.streaming.runtime.operators.windowing;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.util.Collector;
......@@ -47,7 +48,7 @@ public abstract class AbstractKeyedTimePanes<Type, Key, Aggregate, Result> {
public abstract void addElementToLatestPane(Type element) throws Exception;
public abstract void evaluateWindow(Collector<Result> out, TimeWindow window) throws Exception;
public abstract void evaluateWindow(Collector<Result> out, TimeWindow window, AbstractStreamOperator<Result> operator) throws Exception;
public void dispose() {
......
......@@ -21,6 +21,7 @@ package org.apache.flink.streaming.runtime.operators.windowing;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.runtime.util.UnionIterator;
import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.util.Collector;
......@@ -57,16 +58,21 @@ public class AccumulatingKeyedTimePanes<Type, Key, Result> extends AbstractKeyed
}
@Override
public void evaluateWindow(Collector<Result> out, TimeWindow window) throws Exception {
public void evaluateWindow(Collector<Result> out, TimeWindow window,
AbstractStreamOperator<Result> operator) throws Exception
{
if (previousPanes.isEmpty()) {
// optimized path for single pane case (tumbling window)
for (KeyMap.Entry<Key, ArrayList<Type>> entry : latestPane) {
Key key = entry.getKey();
operator.setKeyContext(key);
function.apply(entry.getKey(), window, entry.getValue(), out);
}
}
else {
// general code path for multi-pane case
WindowFunctionTraversal<Key, Type, Result> evaluator = new WindowFunctionTraversal<>(function, window, out);
WindowFunctionTraversal<Key, Type, Result> evaluator = new WindowFunctionTraversal<>(
function, window, out, operator);
traverseAllPanes(evaluator, evaluationPass);
}
......@@ -84,16 +90,21 @@ public class AccumulatingKeyedTimePanes<Type, Key, Result> extends AbstractKeyed
private final UnionIterator<Type> unionIterator;
private final Collector<Result> out;
private final TimeWindow window;
private final AbstractStreamOperator<Result> contextOperator;
private Key currentKey;
private TimeWindow window;
WindowFunctionTraversal(WindowFunction<Type, Result, Key, Window> function, TimeWindow window, Collector<Result> out) {
WindowFunctionTraversal(WindowFunction<Type, Result, Key, Window> function, TimeWindow window,
Collector<Result> out, AbstractStreamOperator<Result> contextOperator) {
this.function = function;
this.out = out;
this.unionIterator = new UnionIterator<>();
this.window = window;
this.contextOperator = contextOperator;
}
......@@ -110,6 +121,7 @@ public class AccumulatingKeyedTimePanes<Type, Key, Result> extends AbstractKeyed
@Override
public void keyDone() throws Exception {
contextOperator.setKeyContext(currentKey);
function.apply(currentKey, window, unionIterator, out);
}
}
......
......@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.operators.windowing;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.util.Collector;
......@@ -50,7 +51,8 @@ public class AggregatingKeyedTimePanes<Type, Key> extends AbstractKeyedTimePanes
}
@Override
public void evaluateWindow(Collector<Type> out, TimeWindow window) throws Exception {
public void evaluateWindow(Collector<Type> out, TimeWindow window,
AbstractStreamOperator<Type> operator) throws Exception {
if (previousPanes.isEmpty()) {
// optimized path for single pane case
for (KeyMap.Entry<Key, Type> entry : latestPane) {
......@@ -59,7 +61,7 @@ public class AggregatingKeyedTimePanes<Type, Key> extends AbstractKeyedTimePanes
}
else {
// general code path for multi-pane case
AggregatingTraversal<Key, Type> evaluator = new AggregatingTraversal<>(reducer, out);
AggregatingTraversal<Key, Type> evaluator = new AggregatingTraversal<>(reducer, out, operator);
traverseAllPanes(evaluator, evaluationPass);
}
......@@ -76,16 +78,21 @@ public class AggregatingKeyedTimePanes<Type, Key> extends AbstractKeyedTimePanes
private final Collector<Type> out;
private final AbstractStreamOperator<Type> operator;
private Type currentValue;
AggregatingTraversal(ReduceFunction<Type> function, Collector<Type> out) {
AggregatingTraversal(ReduceFunction<Type> function, Collector<Type> out,
AbstractStreamOperator<Type> operator) {
this.function = function;
this.out = out;
this.operator = operator;
}
@Override
public void startNewKey(Key key) {
currentValue = null;
operator.setKeyContext(key);
}
@Override
......
......@@ -20,11 +20,15 @@ package org.apache.flink.streaming.runtime.operators.windowing;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.ClosureCleaner;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction;
import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.Output;
......@@ -48,7 +52,9 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
......@@ -89,6 +95,13 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
// ------------------------------------------------------------------------
public AccumulatingAlignedProcessingTimeWindowOperatorTest() {
ClosureCleaner.clean(identitySelector, false);
ClosureCleaner.clean(validatingIdentityFunction, false);
}
// ------------------------------------------------------------------------
@After
public void checkNoTriggerThreadsRunning() {
// make sure that all the threads we trigger are shut down
......@@ -544,6 +557,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
// inject some elements
final int numElementsFirst = 700;
final int numElements = 1000;
for (int i = 0; i < numElementsFirst; i++) {
synchronized (lock) {
op.processElement(new StreamRecord<Integer>(i));
......@@ -560,6 +574,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
resultAtSnapshot = new ArrayList<>(out.getElements());
int afterSnapShot = out.getElements().size();
assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
assertTrue(afterSnapShot <= numElementsFirst);
}
// inject some random elements, which should not show up in the state
......@@ -584,7 +599,6 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
op.open();
// inject some more elements
final int numElements = 1000;
for (int i = numElementsFirst; i < numElements; i++) {
synchronized (lock) {
op.processElement(new StreamRecord<Integer>(i));
......@@ -725,6 +739,64 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
}
}
@Test
public void testKeyValueStateInWindowFunction() {
final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor();
try {
final CollectingOutput<Integer> out = new CollectingOutput<>(50);
final Object lock = new Object();
final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
StatefulFunction.globalCounts.clear();
// tumbling window that triggers every 20 milliseconds
AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op =
new AccumulatingProcessingTimeWindowOperator<>(
new StatefulFunction(), identitySelector,
IntSerializer.INSTANCE, IntSerializer.INSTANCE, 50, 50);
op.setup(mockTask, createTaskConfig(identitySelector, IntSerializer.INSTANCE), out);
op.open();
synchronized (lock) {
op.processElement(new StreamRecord<Integer>(1));
op.processElement(new StreamRecord<Integer>(2));
}
out.waitForNElements(2, 60000);
synchronized (lock) {
op.processElement(new StreamRecord<Integer>(1));
op.processElement(new StreamRecord<Integer>(2));
op.processElement(new StreamRecord<Integer>(1));
op.processElement(new StreamRecord<Integer>(1));
op.processElement(new StreamRecord<Integer>(2));
op.processElement(new StreamRecord<Integer>(2));
}
out.waitForNElements(8, 60000);
List<Integer> result = out.getElements();
assertEquals(8, result.size());
Collections.sort(result);
assertEquals(Arrays.asList(1, 1, 1, 1, 2, 2, 2, 2), result);
assertEquals(4, StatefulFunction.globalCounts.get(1).intValue());
assertEquals(4, StatefulFunction.globalCounts.get(2).intValue());
synchronized (lock) {
op.close();
}
op.dispose();
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
finally {
timerService.shutdown();
}
}
// ------------------------------------------------------------------------
private void assertInvalidParameter(long windowSize, long windowSlide) {
......@@ -771,6 +843,41 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
}
}
// ------------------------------------------------------------------------
private static class StatefulFunction extends RichWindowFunction<Integer, Integer, Integer, TimeWindow> {
// we use a concurrent map here even though there is no concurrency, to
// get "volatile" style access to entries
static final Map<Integer, Integer> globalCounts = new ConcurrentHashMap<>();
private OperatorState<Integer> state;
@Override
public void open(Configuration parameters) {
assertNotNull(getRuntimeContext());
state = getRuntimeContext().getKeyValueState("totalCount", Integer.class, 0);
}
@Override
public void apply(Integer key,
TimeWindow window,
Iterable<Integer> values,
Collector<Integer> out) throws Exception {
for (Integer i : values) {
// we need to update this state before emitting elements. Else, the test's main
// thread will have received all output elements before the state is updated and
// the checks may fail
state.update(state.value() + 1);
globalCounts.put(key, state.value());
out.collect(i);
}
}
}
// ------------------------------------------------------------------------
private static StreamTask<?, ?> createMockTask() {
StreamTask<?, ?> task = mock(StreamTask.class);
when(task.getAccumulatorMap()).thenReturn(new HashMap<String, Accumulator<?, ?>>());
......@@ -821,4 +928,11 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
return mockTask;
}
private static StreamConfig createTaskConfig(KeySelector<?, ?> partitioner, TypeSerializer<?> keySerializer) {
StreamConfig cfg = new StreamConfig(new Configuration());
cfg.setStatePartitioner(partitioner);
cfg.setStateKeySerializer(keySerializer);
return cfg;
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册