From cb2b76dc4c999398133c392c9fa4a8ef82e90fd5 Mon Sep 17 00:00:00 2001
From: Aljoscha Krettek
Date: Tue, 14 Jun 2016 12:18:35 +0200
Subject: [PATCH] [FLINK-3974] Fix object reuse with multi-chaining

Before, a job would fail if object reuse was enabled and multiple operators
were chained to one upstream operator. Now, we create shallow copies in
BroadcastingOutputCollector and DirectedOutput if object reuse is enabled.
---
 .../selector/CopyingDirectedOutput.java            |  67 ++++
 .../collector/selector/DirectedOutput.java         |  22 +-
 .../api/operators/AbstractStreamOperator.java      |  17 +-
 .../api/operators/StreamOperator.java              |   6 -
 .../runtime/tasks/OperatorChain.java               |  58 ++-
 .../operators/StreamOperatorChainingTest.java      | 354 ++++++++++++++++++
 6 files changed, 485 insertions(+), 39 deletions(-)
 create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/collector/selector/CopyingDirectedOutput.java
 create mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java

diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/collector/selector/CopyingDirectedOutput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/collector/selector/CopyingDirectedOutput.java
new file mode 100644
index 00000000000..5f7e2787e3b
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/collector/selector/CopyingDirectedOutput.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.collector.selector;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.streaming.api.graph.StreamEdge;
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+
+/**
+ * Special version of {@link DirectedOutput} that performs a shallow copy of the
+ * {@link StreamRecord} to ensure that multi-chaining works correctly.
+ */
+public class CopyingDirectedOutput<OUT> extends DirectedOutput<OUT> {
+
+	@SuppressWarnings({"unchecked", "rawtypes"})
+	public CopyingDirectedOutput(
+			List<OutputSelector<OUT>> outputSelectors,
+			List<? extends Tuple2<? extends Output<StreamRecord<OUT>>, StreamEdge>> outputs) {
+		super(outputSelectors, outputs);
+	}
+
+	@Override
+	public void collect(StreamRecord<OUT> record) {
+		Set<Output<StreamRecord<OUT>>> selectedOutputs = selectOutputs(record);
+
+		if (selectedOutputs.isEmpty()) {
+			return;
+		}
+
+		Iterator<Output<StreamRecord<OUT>>> it = selectedOutputs.iterator();
+
+		while (true) {
+			Output<StreamRecord<OUT>> out = it.next();
+			if (it.hasNext()) {
+				// we don't have the last output
+				// perform a shallow copy
+				StreamRecord<OUT> shallowCopy = record.copy(record.getValue());
+				out.collect(shallowCopy);
+			} else {
+				// this is the last output
+				out.collect(record);
+				break;
+			}
+		}
+	}
+}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/collector/selector/DirectedOutput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/collector/selector/DirectedOutput.java
index 52c50b362bd..8346013b7cc 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/collector/selector/DirectedOutput.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/collector/selector/DirectedOutput.java
@@ -34,13 +34,13 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 public class DirectedOutput<OUT> implements Output<StreamRecord<OUT>> {
 
-	private final OutputSelector<OUT>[] outputSelectors;
+	protected final OutputSelector<OUT>[] outputSelectors;
 
-	private final Output<StreamRecord<OUT>>[] selectAllOutputs;
+	protected final Output<StreamRecord<OUT>>[] selectAllOutputs;
 
-	private final HashMap<String, Output<StreamRecord<OUT>>[]> outputMap;
+	protected final HashMap<String, Output<StreamRecord<OUT>>[]> outputMap;
 
-	private final Output<StreamRecord<OUT>>[] allOutputs;
+	protected final Output<StreamRecord<OUT>>[] allOutputs;
 
 
 	@SuppressWarnings({"unchecked", "rawtypes"})
@@ -100,9 +100,8 @@ public class DirectedOutput<OUT> implements Output<StreamRecord<OUT>> {
 		}
 	}
 
-	@Override
-	public void collect(StreamRecord<OUT> record) {
-		Set<Output<StreamRecord<OUT>>> selectedOutputs = new HashSet<Output<StreamRecord<OUT>>>(selectAllOutputs.length);
+	protected Set<Output<StreamRecord<OUT>>> selectOutputs(StreamRecord<OUT> record) {
+		Set<Output<StreamRecord<OUT>>> selectedOutputs = new HashSet<>(selectAllOutputs.length);
 		Collections.addAll(selectedOutputs, selectAllOutputs);
 
 		for (OutputSelector<OUT> outputSelector : outputSelectors) {
@@ -115,7 +114,14 @@ public class DirectedOutput<OUT> implements Output<StreamRecord<OUT>> {
 				}
 			}
 		}
-		
+
+		return selectedOutputs;
+	}
+
+	@Override
+	public void collect(StreamRecord<OUT> record) {
+		Set<Output<StreamRecord<OUT>>> selectedOutputs = selectOutputs(record);
+
 		for (Output<StreamRecord<OUT>> out : selectedOutputs) {
 			out.collect(record);
 		}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index dc7bbdb8a97..3efc469506c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -70,9 +70,7 @@ public abstract class AbstractStreamOperator<OUT>
 
 	// A sane default for most operators
 	protected ChainingStrategy chainingStrategy = ChainingStrategy.HEAD;
-	
-	private boolean inputCopyDisabled = false;
-	
+
 	// ---------------- runtime fields ------------------
 
 	/** The task that contains this operator (and other operators in the same chain) */
@@ -323,19 +321,6 @@ public abstract class AbstractStreamOperator<OUT>
 	public final ChainingStrategy getChainingStrategy() {
 		return chainingStrategy;
 	}
-	
-	@Override
-	public boolean isInputCopyingDisabled() {
-		return inputCopyDisabled;
-	}
-	
-	/**
-	 * Enable object-reuse for this operator instance. This overrides the setting in
-	 * the {@link org.apache.flink.api.common.ExecutionConfig}
-	 */
-	public void disableInputCopy() {
-		this.inputCopyDisabled = true;
-	}
 
 	public class CountingOutput implements Output<StreamRecord<OUT>> {
 		private final Output<StreamRecord<OUT>> output;
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
index 9ed715e271d..4572ef17ada 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
@@ -141,12 +141,6 @@ public interface StreamOperator<OUT> extends Serializable {
 
 	void setKeyContextElement2(StreamRecord<?> record) throws Exception;
 
-	/**
-	 * An operator can return true here to disable copying of its input elements. This overrides
-	 * the object-reuse setting on the {@link org.apache.flink.api.common.ExecutionConfig}
-	 */
-	boolean isInputCopyingDisabled();
-	
 	ChainingStrategy getChainingStrategy();
 
 	void setChainingStrategy(ChainingStrategy strategy);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
index 761aa37baf9..0e24516eacf 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
@@ -32,6 +32,7 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.plugable.SerializationDelegate;
+import org.apache.flink.streaming.api.collector.selector.CopyingDirectedOutput;
 import org.apache.flink.streaming.api.collector.selector.DirectedOutput;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -232,13 +233,28 @@ public class OperatorChain<OUT> {
 				for (int i = 0; i < allOutputs.size(); i++) {
 					asArray[i] = allOutputs.get(i).f0;
 				}
-				
-				return new BroadcastingOutputCollector<T>(asArray);
+
+				// This is the inverse of creating the normal ChainingOutput.
+				// If the chaining output does not copy, we need to copy in the broadcast output;
+				// otherwise multi-chaining would not work correctly.
+				if (containingTask.getExecutionConfig().isObjectReuseEnabled()) {
+					return new CopyingBroadcastingOutputCollector<>(asArray);
+				} else {
+					return new BroadcastingOutputCollector<>(asArray);
+				}
 			}
 		}
 		else {
 			// selector present, more complex routing necessary
-			return new DirectedOutput<T>(selectors, allOutputs);
+
+			// This is the inverse of creating the normal ChainingOutput.
+			// If the chaining output does not copy, we need to copy in the directed output;
+			// otherwise multi-chaining would not work correctly.
+			if (containingTask.getExecutionConfig().isObjectReuseEnabled()) {
+				return new CopyingDirectedOutput<>(selectors, allOutputs);
+			} else {
+				return new DirectedOutput<>(selectors, allOutputs);
+			}
 		}
 	}
@@ -261,12 +277,12 @@ public class OperatorChain<OUT> {
 		
 		allOperators.add(chainedOperator);
 		
-		if (containingTask.getExecutionConfig().isObjectReuseEnabled() || chainedOperator.isInputCopyingDisabled()) {
-			return new ChainingOutput<T>(chainedOperator);
+		if (containingTask.getExecutionConfig().isObjectReuseEnabled()) {
+			return new ChainingOutput<>(chainedOperator);
 		}
 		else {
 			TypeSerializer<T> inSerializer = operatorConfig.getTypeSerializerIn1(userCodeClassloader);
-			return new CopyingChainingOutput<T>(chainedOperator, inSerializer);
+			return new CopyingChainingOutput<>(chainedOperator, inSerializer);
 		}
 	}
@@ -339,7 +355,7 @@ public class OperatorChain<OUT> {
 		}
 	}
 
-	private static class CopyingChainingOutput<T> extends ChainingOutput<T> {
+	private static final class CopyingChainingOutput<T> extends ChainingOutput<T> {
 
 		private final TypeSerializer<T> serializer;
@@ -362,9 +378,9 @@ public class OperatorChain<OUT> {
 		}
 	}
 
-	private static final class BroadcastingOutputCollector<T> implements Output<StreamRecord<T>> {
+	private static class BroadcastingOutputCollector<T> implements Output<StreamRecord<T>> {
 
-		private final Output<StreamRecord<T>>[] outputs;
+		protected final Output<StreamRecord<T>>[] outputs;
 
 		public BroadcastingOutputCollector(Output<StreamRecord<T>>[] outputs) {
 			this.outputs = outputs;
 		}
@@ -391,4 +407,28 @@ public class OperatorChain<OUT> {
 			}
 		}
 	}
+
+	/**
+	 * Special version of {@link BroadcastingOutputCollector} that performs a shallow copy of the
+	 * {@link StreamRecord} to ensure that multi-chaining works correctly.
+	 */
+	private static final class CopyingBroadcastingOutputCollector<T> extends BroadcastingOutputCollector<T> {
+
+		public CopyingBroadcastingOutputCollector(Output<StreamRecord<T>>[] outputs) {
+			super(outputs);
+		}
+
+		@Override
+		public void collect(StreamRecord<T> record) {
+
+			for (int i = 0; i < outputs.length - 1; i++) {
+				Output<StreamRecord<T>> output = outputs[i];
+				StreamRecord<T> shallowCopy = record.copy(record.getValue());
+				output.collect(shallowCopy);
+			}
+
+			// don't copy for the last output
+			outputs[outputs.length - 1].collect(record);
+		}
+	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
new file mode 100644
index 00000000000..3b201dc7e2c
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
@@ -0,0 +1,354 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.streaming.runtime.operators;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.operators.testutils.MockEnvironment;
+import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.streaming.api.collector.selector.OutputSelector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SplitStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorChain;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.hamcrest.Matchers.contains;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+/**
+ * Tests for stream operator chaining behaviour.
+ */
+public class StreamOperatorChainingTest {
+
+	// We have to use static fields because the sink functions will go through serialization
+	private static List<String> sink1Results;
+	private static List<String> sink2Results;
+	private static List<String> sink3Results;
+
+	@Test
+	public void testMultiChainingWithObjectReuse() throws Exception {
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.getConfig().enableObjectReuse();
+
+		testMultiChaining(env);
+	}
+
+	@Test
+	public void testMultiChainingWithoutObjectReuse() throws Exception {
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.getConfig().disableObjectReuse();
+
+		testMultiChaining(env);
+	}
+
+	/**
+	 * Verify that multi-chaining works.
+	 */
+	private void testMultiChaining(StreamExecutionEnvironment env) throws Exception {
+
+		// the actual elements will not be used
+		DataStream<Integer> input = env.fromElements(1, 2, 3);
+
+		sink1Results = new ArrayList<>();
+		sink2Results = new ArrayList<>();
+
+		input = input
+				.map(new MapFunction<Integer, Integer>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public Integer map(Integer value) throws Exception {
+						return value;
+					}
+				});
+
+		input
+				.map(new MapFunction<Integer, String>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public String map(Integer value) throws Exception {
+						return "First: " + value;
+					}
+				})
+				.addSink(new SinkFunction<String>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public void invoke(String value) throws Exception {
+						sink1Results.add(value);
+					}
+				});
+
+		input
+				.map(new MapFunction<Integer, String>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public String map(Integer value) throws Exception {
+						return "Second: " + value;
+					}
+				})
+				.addSink(new SinkFunction<String>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public void invoke(String value) throws Exception {
+						sink2Results.add(value);
+					}
+				});
+
+		// we build our own StreamTask and OperatorChain
+		JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+
+		Assert.assertTrue(jobGraph.getVerticesSortedTopologicallyFromSources().size() == 2);
+
+		JobVertex chainedVertex = jobGraph.getVerticesSortedTopologicallyFromSources().get(1);
+
+		Configuration configuration = chainedVertex.getConfiguration();
+
+		StreamConfig streamConfig = new StreamConfig(configuration);
+
+		StreamMap<Integer, Integer> headOperator =
+				streamConfig.getStreamOperator(Thread.currentThread().getContextClassLoader());
+
+		StreamTask<Integer, StreamMap<Integer, Integer>> mockTask =
+				createMockTask(streamConfig, chainedVertex.getName());
+
+		OperatorChain<Integer> operatorChain = new OperatorChain<>(
+				mockTask,
+				headOperator,
+				mock(AccumulatorRegistry.Reporter.class));
+
+		headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint());
+
+		for (StreamOperator<?> operator : operatorChain.getAllOperators()) {
+			if (operator != null) {
+				operator.open();
+			}
+		}
+
+		headOperator.processElement(new StreamRecord<>(1));
+		headOperator.processElement(new StreamRecord<>(2));
+		headOperator.processElement(new StreamRecord<>(3));
+
+		assertThat(sink1Results, contains("First: 1", "First: 2", "First: 3"));
+		assertThat(sink2Results, contains("Second: 1", "Second: 2", "Second: 3"));
+	}
+
+	@Test
+	public void testMultiChainingWithSplitWithObjectReuse() throws Exception {
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.getConfig().enableObjectReuse();
+
+		testMultiChainingWithSplit(env);
+	}
+
+	@Test
+	public void testMultiChainingWithSplitWithoutObjectReuse() throws Exception {
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.getConfig().disableObjectReuse();
+
+		testMultiChainingWithSplit(env);
+	}
+
+	/**
+	 * Verify that multi-chaining works when the output of the head operator is split.
+	 */
+	private void testMultiChainingWithSplit(StreamExecutionEnvironment env) throws Exception {
+
+		// the actual elements will not be used
+		DataStream<Integer> input = env.fromElements(1, 2, 3);
+
+		sink1Results = new ArrayList<>();
+		sink2Results = new ArrayList<>();
+		sink3Results = new ArrayList<>();
+
+		input = input
+				.map(new MapFunction<Integer, Integer>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public Integer map(Integer value) throws Exception {
+						return value;
+					}
+				});
+
+		SplitStream<Integer> split = input.split(new OutputSelector<Integer>() {
+			private static final long serialVersionUID = 1L;
+
+			@Override
+			public Iterable<String> select(Integer value) {
+				if (value.equals(1)) {
+					return Collections.singletonList("one");
+				} else {
+					return Collections.singletonList("other");
+				}
+			}
+		});
+
+		split.select("one")
+				.map(new MapFunction<Integer, String>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public String map(Integer value) throws Exception {
+						return "First 1: " + value;
+					}
+				})
+				.addSink(new SinkFunction<String>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public void invoke(String value) throws Exception {
+						sink1Results.add(value);
+					}
+				});
+
+		split.select("one")
+				.map(new MapFunction<Integer, String>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public String map(Integer value) throws Exception {
+						return "First 2: " + value;
+					}
+				})
+				.addSink(new SinkFunction<String>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public void invoke(String value) throws Exception {
+						sink2Results.add(value);
+					}
+				});
+
+		split.select("other")
+				.map(new MapFunction<Integer, String>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public String map(Integer value) throws Exception {
+						return "Second: " + value;
+					}
+				})
+				.addSink(new SinkFunction<String>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public void invoke(String value) throws Exception {
+						sink3Results.add(value);
+					}
+				});
+
+		// we build our own StreamTask and OperatorChain
+		JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+
+		Assert.assertTrue(jobGraph.getVerticesSortedTopologicallyFromSources().size() == 2);
+
+		JobVertex chainedVertex = jobGraph.getVerticesSortedTopologicallyFromSources().get(1);
+
+		Configuration configuration = chainedVertex.getConfiguration();
+
+		StreamConfig streamConfig = new StreamConfig(configuration);
+
+		StreamMap<Integer, Integer> headOperator =
+				streamConfig.getStreamOperator(Thread.currentThread().getContextClassLoader());
+
+		StreamTask<Integer, StreamMap<Integer, Integer>> mockTask =
+				createMockTask(streamConfig, chainedVertex.getName());
+
+		OperatorChain<Integer> operatorChain = new OperatorChain<>(
+				mockTask,
+				headOperator,
+				mock(AccumulatorRegistry.Reporter.class));
+
+		headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint());
+
+		for (StreamOperator<?> operator : operatorChain.getAllOperators()) {
+			if (operator != null) {
+				operator.open();
+			}
+		}
+
+		headOperator.processElement(new StreamRecord<>(1));
+		headOperator.processElement(new StreamRecord<>(2));
+		headOperator.processElement(new StreamRecord<>(3));
+
+		assertThat(sink1Results, contains("First 1: 1"));
+		assertThat(sink2Results, contains("First 2: 1"));
+		assertThat(sink3Results, contains("Second: 2", "Second: 3"));
+	}
+
+	private <IN, OT extends StreamOperator<IN>> StreamTask<IN, OT> createMockTask(StreamConfig streamConfig, String taskName) {
+		final Object checkpointLock = new Object();
+		final Environment env = new MockEnvironment(taskName, 3 * 1024 * 1024, new MockInputSplitProvider(), 1024);
+
+		@SuppressWarnings("unchecked")
+		StreamTask<IN, OT> mockTask = mock(StreamTask.class);
+		when(mockTask.getName()).thenReturn("Mock Task");
+		when(mockTask.getCheckpointLock()).thenReturn(checkpointLock);
+		when(mockTask.getConfiguration()).thenReturn(streamConfig);
+		when(mockTask.getEnvironment()).thenReturn(env);
+		when(mockTask.getExecutionConfig()).thenReturn(new ExecutionConfig().enableObjectReuse());
+
+		try {
+			doAnswer(new Answer<AbstractStateBackend>() {
+				@Override
+				public AbstractStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
+					final String operatorIdentifier = (String) invocationOnMock.getArguments()[0];
+					final TypeSerializer<?> keySerializer = (TypeSerializer<?>) invocationOnMock.getArguments()[1];
+					MemoryStateBackend backend = MemoryStateBackend.create();
+					backend.initializeForJob(env, operatorIdentifier, keySerializer);
+					return backend;
+				}
+			}).when(mockTask).createStateBackend(any(String.class), any(TypeSerializer.class));
+		} catch (Exception e) {
+			throw new RuntimeException(e.getMessage(), e);
+		}
+
+		return mockTask;
+	}
+
+}
-- 
GitLab