Commit cb2b76dc authored by Aljoscha Krettek

[FLINK-3974] Fix object reuse with multi-chaining

Before, a job would fail if object reuse was enabled and multiple
operators were chained to one upstream operator. Now, we create shallow
copies in BroadcastingOutputCollector and DirectedOutput if object reuse
is enabled.
Parent ee3c7a88
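To make the fix concrete before the diff: with object reuse enabled, a chaining output hands the same StreamRecord instance to every downstream operator, so one branch mutating the record corrupts what the other branches see. The copying outputs below avoid this with a shallow copy: a new record wrapper around the same value. A minimal standalone sketch of that semantics (not part of the commit; it assumes StreamRecord's copy/replace API as used in the diff):

import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;

public class ShallowCopySketch {
    public static void main(String[] args) {
        StreamRecord<String> original = new StreamRecord<>("payload", 42L);

        // As in the diff: copy(getValue()) creates a new wrapper around the
        // SAME value object, so the payload itself is not deep-copied.
        StreamRecord<String> copy = original.copy(original.getValue());

        // The wrappers are independent: replacing the value in one no longer
        // leaks into the record seen by the other chained consumer.
        copy.replace("other payload");

        System.out.println(original.getValue()); // prints: payload
        System.out.println(copy.getValue());     // prints: other payload
    }
}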
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.api.collector.selector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.graph.StreamEdge;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
/**
* Special version of {@link DirectedOutput} that performs a shallow copy of the
* {@link StreamRecord} to ensure that multi-chaining works correctly.
*/
public class CopyingDirectedOutput<OUT> extends DirectedOutput<OUT> {
@SuppressWarnings({"unchecked", "rawtypes"})
public CopyingDirectedOutput(
List<OutputSelector<OUT>> outputSelectors,
List<Tuple2<Output<StreamRecord<OUT>>, StreamEdge>> outputs) {
super(outputSelectors, outputs);
}
@Override
public void collect(StreamRecord<OUT> record) {
Set<Output<StreamRecord<OUT>>> selectedOutputs = selectOutputs(record);
if (selectedOutputs.isEmpty()) {
return;
}
Iterator<Output<StreamRecord<OUT>>> it = selectedOutputs.iterator();
while (true) {
Output<StreamRecord<OUT>> out = it.next();
if (it.hasNext()) {
// this is not the last output
// perform a shallow copy
StreamRecord<OUT> shallowCopy = record.copy(record.getValue());
out.collect(shallowCopy);
} else {
// this is the last output
out.collect(record);
break;
}
}
}
}
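Design note: the loop copies for every selected output except the last, which receives the original record, so each emitted element costs one copy fewer than a naive copy-for-all approach. The unchecked first it.next() is safe because collect() returns early when no outputs were selected.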
@@ -34,13 +34,13 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
public class DirectedOutput<OUT> implements Output<StreamRecord<OUT>> {
-private final OutputSelector<OUT>[] outputSelectors;
+protected final OutputSelector<OUT>[] outputSelectors;
-private final Output<StreamRecord<OUT>>[] selectAllOutputs;
+protected final Output<StreamRecord<OUT>>[] selectAllOutputs;
-private final HashMap<String, Output<StreamRecord<OUT>>[]> outputMap;
+protected final HashMap<String, Output<StreamRecord<OUT>>[]> outputMap;
-private final Output<StreamRecord<OUT>>[] allOutputs;
+protected final Output<StreamRecord<OUT>>[] allOutputs;
@SuppressWarnings({"unchecked", "rawtypes"})
@@ -100,9 +100,8 @@ public class DirectedOutput<OUT> implements Output<StreamRecord<OUT>> {
}
}
-@Override
-public void collect(StreamRecord<OUT> record) {
-Set<Output<StreamRecord<OUT>>> selectedOutputs = new HashSet<Output<StreamRecord<OUT>>>(selectAllOutputs.length);
+protected Set<Output<StreamRecord<OUT>>> selectOutputs(StreamRecord<OUT> record) {
+Set<Output<StreamRecord<OUT>>> selectedOutputs = new HashSet<>(selectAllOutputs.length);
Collections.addAll(selectedOutputs, selectAllOutputs);
for (OutputSelector<OUT> outputSelector : outputSelectors) {
@@ -115,7 +114,14 @@ public class DirectedOutput<OUT> implements Output<StreamRecord<OUT>> {
}
}
}
+return selectedOutputs;
+}
+@Override
+public void collect(StreamRecord<OUT> record) {
+Set<Output<StreamRecord<OUT>>> selectedOutputs = selectOutputs(record);
for (Output<StreamRecord<OUT>> out : selectedOutputs) {
out.collect(record);
}
@@ -70,9 +70,7 @@ public abstract class AbstractStreamOperator<OUT>
// A sane default for most operators
protected ChainingStrategy chainingStrategy = ChainingStrategy.HEAD;
-private boolean inputCopyDisabled = false;
// ---------------- runtime fields ------------------
/** The task that contains this operator (and other operators in the same chain) */
@@ -323,19 +321,6 @@ public abstract class AbstractStreamOperator<OUT>
public final ChainingStrategy getChainingStrategy() {
return chainingStrategy;
}
-@Override
-public boolean isInputCopyingDisabled() {
-return inputCopyDisabled;
-}
-/**
-* Enable object-reuse for this operator instance. This overrides the setting in
-* the {@link org.apache.flink.api.common.ExecutionConfig}
-*/
-public void disableInputCopy() {
-this.inputCopyDisabled = true;
-}
public class CountingOutput implements Output<StreamRecord<OUT>> {
private final Output<StreamRecord<OUT>> output;
@@ -141,12 +141,6 @@ public interface StreamOperator<OUT> extends Serializable {
void setKeyContextElement2(StreamRecord<?> record) throws Exception;
-/**
-* An operator can return true here to disable copying of its input elements. This overrides
-* the object-reuse setting on the {@link org.apache.flink.api.common.ExecutionConfig}
-*/
-boolean isInputCopyingDisabled();
ChainingStrategy getChainingStrategy();
void setChainingStrategy(ChainingStrategy strategy);
@@ -32,6 +32,7 @@ import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.plugable.SerializationDelegate;
+import org.apache.flink.streaming.api.collector.selector.CopyingDirectedOutput;
import org.apache.flink.streaming.api.collector.selector.DirectedOutput;
import org.apache.flink.streaming.api.collector.selector.OutputSelector;
import org.apache.flink.streaming.api.watermark.Watermark;
@@ -232,13 +233,28 @@ public class OperatorChain<OUT> {
for (int i = 0; i < allOutputs.size(); i++) {
asArray[i] = allOutputs.get(i).f0;
}
-return new BroadcastingOutputCollector<T>(asArray);
+// This is the inverse of creating the normal ChainingOutput.
+// If the chaining output does not copy we need to copy in the broadcast output,
+// otherwise multi-chaining would not work correctly.
+if (containingTask.getExecutionConfig().isObjectReuseEnabled()) {
+return new CopyingBroadcastingOutputCollector<>(asArray);
+} else {
+return new BroadcastingOutputCollector<>(asArray);
+}
}
}
else {
// selector present, more complex routing necessary
-return new DirectedOutput<T>(selectors, allOutputs);
+// This is the inverse of creating the normal ChainingOutput.
+// If the chaining output does not copy we need to copy in the directed output,
+// otherwise multi-chaining would not work correctly.
+if (containingTask.getExecutionConfig().isObjectReuseEnabled()) {
+return new CopyingDirectedOutput<>(selectors, allOutputs);
+} else {
+return new DirectedOutput<>(selectors, allOutputs);
+}
}
}
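The branching above is driven purely by the user-facing object-reuse switch. A minimal sketch of how a job opts in (mirroring the tests further down; names as in the Flink API):

import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class ObjectReuseOptIn {
    public static void main(String[] args) {
        StreamExecutionEnvironment env =
            StreamExecutionEnvironment.getExecutionEnvironment();
        // With reuse enabled, OperatorChain picks the Copying* output variants
        // so fan-out to multiple chained operators gets independent wrappers.
        env.getConfig().enableObjectReuse();
    }
}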
@@ -261,12 +277,12 @@ public class OperatorChain<OUT> {
allOperators.add(chainedOperator);
-if (containingTask.getExecutionConfig().isObjectReuseEnabled() || chainedOperator.isInputCopyingDisabled()) {
-return new ChainingOutput<IN>(chainedOperator);
+if (containingTask.getExecutionConfig().isObjectReuseEnabled()) {
+return new ChainingOutput<>(chainedOperator);
}
else {
TypeSerializer<IN> inSerializer = operatorConfig.getTypeSerializerIn1(userCodeClassloader);
-return new CopyingChainingOutput<IN>(chainedOperator, inSerializer);
+return new CopyingChainingOutput<>(chainedOperator, inSerializer);
}
}
@@ -339,7 +355,7 @@ public class OperatorChain<OUT> {
}
}
-private static class CopyingChainingOutput<T> extends ChainingOutput<T> {
+private static final class CopyingChainingOutput<T> extends ChainingOutput<T> {
private final TypeSerializer<T> serializer;
@@ -362,9 +378,9 @@ public class OperatorChain<OUT> {
}
}
-private static final class BroadcastingOutputCollector<T> implements Output<StreamRecord<T>> {
+private static class BroadcastingOutputCollector<T> implements Output<StreamRecord<T>> {
-private final Output<StreamRecord<T>>[] outputs;
+protected final Output<StreamRecord<T>>[] outputs;
public BroadcastingOutputCollector(Output<StreamRecord<T>>[] outputs) {
this.outputs = outputs;
@@ -391,4 +407,28 @@ public class OperatorChain<OUT> {
}
}
}
+/**
+* Special version of {@link BroadcastingOutputCollector} that performs a shallow copy of the
+* {@link StreamRecord} to ensure that multi-chaining works correctly.
+*/
+private static final class CopyingBroadcastingOutputCollector<T> extends BroadcastingOutputCollector<T> {
+public CopyingBroadcastingOutputCollector(Output<StreamRecord<T>>[] outputs) {
+super(outputs);
+}
+@Override
+public void collect(StreamRecord<T> record) {
+for (int i = 0; i < outputs.length - 1; i++) {
+Output<StreamRecord<T>> output = outputs[i];
+StreamRecord<T> shallowCopy = record.copy(record.getValue());
+output.collect(shallowCopy);
+}
+// don't copy for the last output
+outputs[outputs.length - 1].collect(record);
+}
+}
}
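CopyingBroadcastingOutputCollector applies the same last-output optimization as CopyingDirectedOutput above; because the outputs are a plain array rather than a Set, a simple index loop up to outputs.length - 1 replaces the iterator-based look-ahead.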
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.runtime.operators;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.runtime.state.AbstractStateBackend;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.api.collector.selector.OutputSelector;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SplitStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.OperatorChain;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static org.hamcrest.Matchers.contains;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.hamcrest.MatcherAssert.assertThat;
/**
* Tests for stream operator chaining behaviour.
*/
public class StreamOperatorChainingTest {
// We have to use static fields because the sink functions will go through serialization
private static List<String> sink1Results;
private static List<String> sink2Results;
private static List<String> sink3Results;
@Test
public void testMultiChainingWithObjectReuse() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.getConfig().enableObjectReuse();
testMultiChaining(env);
}
@Test
public void testMultiChainingWithoutObjectReuse() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.getConfig().disableObjectReuse();
testMultiChaining(env);
}
/**
* Verify that multi-chaining works.
*/
private void testMultiChaining(StreamExecutionEnvironment env) throws Exception {
// the actual elements will not be used
DataStream<Integer> input = env.fromElements(1, 2, 3);
sink1Results = new ArrayList<>();
sink2Results = new ArrayList<>();
input = input
.map(new MapFunction<Integer, Integer>() {
private static final long serialVersionUID = 1L;
@Override
public Integer map(Integer value) throws Exception {
return value;
}
});
input
.map(new MapFunction<Integer, String>() {
private static final long serialVersionUID = 1L;
@Override
public String map(Integer value) throws Exception {
return "First: " + value;
}
})
.addSink(new SinkFunction<String>() {
private static final long serialVersionUID = 1L;
@Override
public void invoke(String value) throws Exception {
sink1Results.add(value);
}
});
input
.map(new MapFunction<Integer, String>() {
private static final long serialVersionUID = 1L;
@Override
public String map(Integer value) throws Exception {
return "Second: " + value;
}
})
.addSink(new SinkFunction<String>() {
private static final long serialVersionUID = 1L;
@Override
public void invoke(String value) throws Exception {
sink2Results.add(value);
}
});
// we build our own StreamTask and OperatorChain
JobGraph jobGraph = env.getStreamGraph().getJobGraph();
Assert.assertEquals(2, jobGraph.getVerticesSortedTopologicallyFromSources().size());
JobVertex chainedVertex = jobGraph.getVerticesSortedTopologicallyFromSources().get(1);
Configuration configuration = chainedVertex.getConfiguration();
StreamConfig streamConfig = new StreamConfig(configuration);
StreamMap<Integer, Integer> headOperator =
streamConfig.getStreamOperator(Thread.currentThread().getContextClassLoader());
StreamTask<Integer, StreamMap<Integer, Integer>> mockTask =
createMockTask(streamConfig, chainedVertex.getName(), env.getConfig());
OperatorChain<Integer> operatorChain = new OperatorChain<>(
mockTask,
headOperator,
mock(AccumulatorRegistry.Reporter.class));
headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint());
for (StreamOperator<?> operator : operatorChain.getAllOperators()) {
if (operator != null) {
operator.open();
}
}
headOperator.processElement(new StreamRecord<>(1));
headOperator.processElement(new StreamRecord<>(2));
headOperator.processElement(new StreamRecord<>(3));
assertThat(sink1Results, contains("First: 1", "First: 2", "First: 3"));
assertThat(sink2Results, contains("Second: 1", "Second: 2", "Second: 3"));
}
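Both sink lists must contain the full sequence regardless of the reuse setting; before this fix, the object-reuse variant of this scenario failed on the fan-out from the head operator, which is exactly what FLINK-3974 reported.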
@Test
public void testMultiChainingWithSplitWithObjectReuse() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.getConfig().enableObjectReuse();
testMultiChainingWithSplit(env);
}
@Test
public void testMultiChainingWithSplitWithoutObjectReuse() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.getConfig().disableObjectReuse();
testMultiChainingWithSplit(env);
}
/**
* Verify that multi-chaining works with a split.
*/
private void testMultiChainingWithSplit(StreamExecutionEnvironment env) throws Exception {
// the actual elements will not be used
DataStream<Integer> input = env.fromElements(1, 2, 3);
sink1Results = new ArrayList<>();
sink2Results = new ArrayList<>();
sink3Results = new ArrayList<>();
input = input
.map(new MapFunction<Integer, Integer>() {
private static final long serialVersionUID = 1L;
@Override
public Integer map(Integer value) throws Exception {
return value;
}
});
SplitStream<Integer> split = input.split(new OutputSelector<Integer>() {
private static final long serialVersionUID = 1L;
@Override
public Iterable<String> select(Integer value) {
if (value.equals(1)) {
return Collections.singletonList("one");
} else {
return Collections.singletonList("other");
}
}
});
split.select("one")
.map(new MapFunction<Integer, String>() {
private static final long serialVersionUID = 1L;
@Override
public String map(Integer value) throws Exception {
return "First 1: " + value;
}
})
.addSink(new SinkFunction<String>() {
private static final long serialVersionUID = 1L;
@Override
public void invoke(String value) throws Exception {
sink1Results.add(value);
}
});
split.select("one")
.map(new MapFunction<Integer, String>() {
private static final long serialVersionUID = 1L;
@Override
public String map(Integer value) throws Exception {
return "First 2: " + value;
}
})
.addSink(new SinkFunction<String>() {
private static final long serialVersionUID = 1L;
@Override
public void invoke(String value) throws Exception {
sink2Results.add(value);
}
});
split.select("other")
.map(new MapFunction<Integer, String>() {
private static final long serialVersionUID = 1L;
@Override
public String map(Integer value) throws Exception {
return "Second: " + value;
}
})
.addSink(new SinkFunction<String>() {
private static final long serialVersionUID = 1L;
@Override
public void invoke(String value) throws Exception {
sink3Results.add(value);
}
});
// we build our own StreamTask and OperatorChain
JobGraph jobGraph = env.getStreamGraph().getJobGraph();
Assert.assertEquals(2, jobGraph.getVerticesSortedTopologicallyFromSources().size());
JobVertex chainedVertex = jobGraph.getVerticesSortedTopologicallyFromSources().get(1);
Configuration configuration = chainedVertex.getConfiguration();
StreamConfig streamConfig = new StreamConfig(configuration);
StreamMap<Integer, Integer> headOperator =
streamConfig.getStreamOperator(Thread.currentThread().getContextClassLoader());
StreamTask<Integer, StreamMap<Integer, Integer>> mockTask =
createMockTask(streamConfig, chainedVertex.getName(), env.getConfig());
OperatorChain<Integer> operatorChain = new OperatorChain<>(
mockTask,
headOperator,
mock(AccumulatorRegistry.Reporter.class));
headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint());
for (StreamOperator<?> operator : operatorChain.getAllOperators()) {
if (operator != null) {
operator.open();
}
}
headOperator.processElement(new StreamRecord<>(1));
headOperator.processElement(new StreamRecord<>(2));
headOperator.processElement(new StreamRecord<>(3));
assertThat(sink1Results, contains("First 1: 1"));
assertThat(sink2Results, contains("First 2: 1"));
assertThat(sink3Results, contains("Second: 2", "Second: 3"));
}
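The assertions trace the selector's routing: element 1 is tagged "one" and reaches the first two branches, while elements 2 and 3 are tagged "other" and reach only the third sink. With object reuse enabled, this exercises CopyingDirectedOutput with more than one selected output per record, i.e. the path that performs shallow copies.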
private <IN, OT extends StreamOperator<IN>> StreamTask<IN, OT> createMockTask(StreamConfig streamConfig, String taskName, ExecutionConfig executionConfig) {
final Object checkpointLock = new Object();
final Environment env = new MockEnvironment(taskName, 3 * 1024 * 1024, new MockInputSplitProvider(), 1024);
@SuppressWarnings("unchecked")
StreamTask<IN, OT> mockTask = mock(StreamTask.class);
when(mockTask.getName()).thenReturn("Mock Task");
when(mockTask.getCheckpointLock()).thenReturn(checkpointLock);
when(mockTask.getConfiguration()).thenReturn(streamConfig);
when(mockTask.getEnvironment()).thenReturn(env);
// use the ExecutionConfig of the test variant so that the object-reuse
// setting is actually respected, instead of always enabling reuse
when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
try {
doAnswer(new Answer<AbstractStateBackend>() {
@Override
public AbstractStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
final String operatorIdentifier = (String) invocationOnMock.getArguments()[0];
final TypeSerializer<?> keySerializer = (TypeSerializer<?>) invocationOnMock.getArguments()[1];
MemoryStateBackend backend = MemoryStateBackend.create();
backend.initializeForJob(env, operatorIdentifier, keySerializer);
return backend;
}
}).when(mockTask).createStateBackend(any(String.class), any(TypeSerializer.class));
} catch (Exception e) {
throw new RuntimeException(e.getMessage(), e);
}
return mockTask;
}
}