提交 3b2ee23b 编写于 作者: S Stephan Ewen

[streaming] Integrate new checkpointed interface with StreamTask,...

[streaming] Integrate new checkpointed interface with StreamTask, StreamOperator, and PersistentKafkaSource
上级 ededb6b7
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.jobgraph.tasks;
import java.io.IOException;
/**
 * Describes the barrier checkpointing behavior of an operator, as used for
 * fault tolerance.
 *
 * <p>The interface consists of two callbacks:
 * {@link #broadcastBarrierFromSource(long)} notifies the operator of a new
 * checkpoint barrier, and {@link #confirmBarrier(long)} acknowledges that the
 * checkpoint belonging to a specific barrier is complete.
 */
public interface BarrierTransceiver {

	/**
	 * Callback that notifies the operator of a new checkpoint barrier.
	 *
	 * @param barrierID the ID of the checkpoint barrier the operator is notified about
	 */
	void broadcastBarrierFromSource(long barrierID);

	/**
	 * Callback that confirms that a barrier checkpoint is complete.
	 *
	 * @param barrierID the ID of the barrier whose checkpoint is being confirmed
	 * @throws IOException declared so that implementations may report an I/O
	 *         failure raised while confirming the barrier
	 */
	void confirmBarrier(long barrierID) throws IOException;
}
...@@ -20,5 +20,5 @@ package org.apache.flink.runtime.jobgraph.tasks; ...@@ -20,5 +20,5 @@ package org.apache.flink.runtime.jobgraph.tasks;
public interface CheckpointCommittingOperator { public interface CheckpointCommittingOperator {
void confirmCheckpoint(long checkpointId, long timestamp); void confirmCheckpoint(long checkpointId, long timestamp) throws Exception;
} }
...@@ -20,5 +20,5 @@ package org.apache.flink.runtime.jobgraph.tasks; ...@@ -20,5 +20,5 @@ package org.apache.flink.runtime.jobgraph.tasks;
public interface CheckpointedOperator { public interface CheckpointedOperator {
void triggerCheckpoint(long checkpointId, long timestamp); void triggerCheckpoint(long checkpointId, long timestamp) throws Exception;
} }
...@@ -33,6 +33,6 @@ public interface OperatorStateCarrier<T extends StateHandle<?>> { ...@@ -33,6 +33,6 @@ public interface OperatorStateCarrier<T extends StateHandle<?>> {
* *
* @param stateHandle The handle to the state. * @param stateHandle The handle to the state.
*/ */
public void setInitialState(T stateHandle); public void setInitialState(T stateHandle) throws Exception;
} }
...@@ -18,23 +18,23 @@ ...@@ -18,23 +18,23 @@
package org.apache.flink.runtime.state; package org.apache.flink.runtime.state;
import java.util.Map; import java.io.Serializable;
/** /**
* A StateHandle that includes a map of operator states directly. * A StateHandle that includes a map of operator states directly.
*/ */
public class LocalStateHandle implements StateHandle<Map<String, OperatorState<?>>> { public class LocalStateHandle implements StateHandle<Serializable> {
private static final long serialVersionUID = 2093619217898039610L; private static final long serialVersionUID = 2093619217898039610L;
private final Map<String, OperatorState<?>> stateMap; private final Serializable state;
public LocalStateHandle(Map<String,OperatorState<?>> state) { public LocalStateHandle(Serializable state) {
this.stateMap = state; this.state = state;
} }
@Override @Override
public Map<String,OperatorState<?>> getState() { public Serializable getState() {
return stateMap; return state;
} }
} }
...@@ -37,7 +37,9 @@ public class StateUtils { ...@@ -37,7 +37,9 @@ public class StateUtils {
* @param state The state handle. * @param state The state handle.
* @param <T> Type bound for the * @param <T> Type bound for the
*/ */
public static <T extends StateHandle<?>> void setOperatorState(OperatorStateCarrier<?> op, StateHandle<?> state) { public static <T extends StateHandle<?>> void setOperatorState(OperatorStateCarrier<?> op, StateHandle<?> state)
throws Exception
{
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
OperatorStateCarrier<T> typedOp = (OperatorStateCarrier<T>) op; OperatorStateCarrier<T> typedOp = (OperatorStateCarrier<T>) op;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
......
...@@ -45,7 +45,6 @@ import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; ...@@ -45,7 +45,6 @@ import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.jobgraph.tasks.BarrierTransceiver;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointCommittingOperator; import org.apache.flink.runtime.jobgraph.tasks.CheckpointCommittingOperator;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointedOperator; import org.apache.flink.runtime.jobgraph.tasks.CheckpointedOperator;
import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier; import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier;
...@@ -57,6 +56,7 @@ import org.apache.flink.runtime.state.StateHandle; ...@@ -57,6 +56,7 @@ import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.runtime.state.StateUtils; import org.apache.flink.runtime.state.StateUtils;
import org.apache.flink.runtime.util.SerializedValue; import org.apache.flink.runtime.util.SerializedValue;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
...@@ -872,23 +872,6 @@ public class Task implements Runnable { ...@@ -872,23 +872,6 @@ public class Task implements Runnable {
}; };
executeAsyncCallRunnable(runnable, "Checkpoint Trigger"); executeAsyncCallRunnable(runnable, "Checkpoint Trigger");
} }
else if (invokable instanceof BarrierTransceiver) {
final BarrierTransceiver barrierTransceiver = (BarrierTransceiver) invokable;
final Logger logger = LOG;
Runnable runnable = new Runnable() {
@Override
public void run() {
try {
barrierTransceiver.broadcastBarrierFromSource(checkpointID);
}
catch (Throwable t) {
logger.error("Error while triggering checkpoint barriers", t);
}
}
};
executeAsyncCallRunnable(runnable, "Checkpoint Trigger");
}
else { else {
LOG.error("Task received a checkpoint request, but is not a checkpointing task - " LOG.error("Task received a checkpoint request, but is not a checkpointing task - "
+ taskNameWithSubtask); + taskNameWithSubtask);
......
...@@ -27,15 +27,14 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex; ...@@ -27,15 +27,14 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
import org.apache.flink.runtime.state.LocalStateHandle; import org.apache.flink.runtime.state.LocalStateHandle;
import org.apache.flink.runtime.state.OperatorState;
import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.runtime.util.SerializableObject;
import org.apache.flink.runtime.util.SerializedValue; import org.apache.flink.runtime.util.SerializedValue;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mockito; import org.mockito.Mockito;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
...@@ -52,7 +51,7 @@ public class CheckpointStateRestoreTest { ...@@ -52,7 +51,7 @@ public class CheckpointStateRestoreTest {
public void testSetState() { public void testSetState() {
try { try {
final SerializedValue<StateHandle<?>> serializedState = new SerializedValue<StateHandle<?>>( final SerializedValue<StateHandle<?>> serializedState = new SerializedValue<StateHandle<?>>(
new LocalStateHandle(Collections.<String,OperatorState<?>>emptyMap())); new LocalStateHandle(new SerializableObject()));
final JobID jid = new JobID(); final JobID jid = new JobID();
final JobVertexID statefulId = new JobVertexID(); final JobVertexID statefulId = new JobVertexID();
...@@ -121,7 +120,7 @@ public class CheckpointStateRestoreTest { ...@@ -121,7 +120,7 @@ public class CheckpointStateRestoreTest {
public void testStateOnlyPartiallyAvailable() { public void testStateOnlyPartiallyAvailable() {
try { try {
final SerializedValue<StateHandle<?>> serializedState = new SerializedValue<StateHandle<?>>( final SerializedValue<StateHandle<?>> serializedState = new SerializedValue<StateHandle<?>>(
new LocalStateHandle(Collections.<String,OperatorState<?>>emptyMap())); new LocalStateHandle(new SerializableObject()));
final JobID jid = new JobID(); final JobID jid = new JobID();
final JobVertexID statefulId = new JobVertexID(); final JobVertexID statefulId = new JobVertexID();
......
...@@ -28,6 +28,7 @@ import com.google.common.base.Preconditions; ...@@ -28,6 +28,7 @@ import com.google.common.base.Preconditions;
import kafka.consumer.ConsumerConfig; import kafka.consumer.ConsumerConfig;
import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.OperatorState; import org.apache.flink.runtime.state.OperatorState;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext; import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;
import org.apache.flink.streaming.connectors.ConnectorSource; import org.apache.flink.streaming.connectors.ConnectorSource;
import org.apache.flink.streaming.connectors.kafka.api.simple.iterator.KafkaConsumerIterator; import org.apache.flink.streaming.connectors.kafka.api.simple.iterator.KafkaConsumerIterator;
...@@ -51,7 +52,7 @@ import org.slf4j.LoggerFactory; ...@@ -51,7 +52,7 @@ import org.slf4j.LoggerFactory;
* @param <OUT> * @param <OUT>
* Type of the messages on the topic. * Type of the messages on the topic.
*/ */
public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> { public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> implements Checkpointed<HashMap<Integer, KafkaOffset>> {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
private static final Logger LOG = LoggerFactory.getLogger(PersistentKafkaSource.class); private static final Logger LOG = LoggerFactory.getLogger(PersistentKafkaSource.class);
...@@ -202,13 +203,13 @@ public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> { ...@@ -202,13 +203,13 @@ public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> {
if (indexOfSubtask >= numberOfPartitions) { if (indexOfSubtask >= numberOfPartitions) {
LOG.info("Creating idle consumer because this subtask ({}) is higher than the number partitions ({})", indexOfSubtask + 1, numberOfPartitions); LOG.info("Creating idle consumer because this subtask ({}) is higher than the number partitions ({})", indexOfSubtask + 1, numberOfPartitions);
iterator = new KafkaIdleConsumerIterator(); iterator = new KafkaIdleConsumerIterator();
} else { }
if (context.containsState("kafka")) { else {
if (partitionOffsets != null) {
// we have restored state
LOG.info("Initializing PersistentKafkaSource from existing state."); LOG.info("Initializing PersistentKafkaSource from existing state.");
kafkaOffSetOperatorState = (OperatorState<Map<Integer, KafkaOffset>>) context.getState("kafka"); }
else {
partitionOffsets = kafkaOffSetOperatorState.getState();
} else {
LOG.info("No existing state found. Creating new"); LOG.info("No existing state found. Creating new");
partitionOffsets = new HashMap<Integer, KafkaOffset>(); partitionOffsets = new HashMap<Integer, KafkaOffset>();
...@@ -217,8 +218,6 @@ public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> { ...@@ -217,8 +218,6 @@ public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> {
} }
kafkaOffSetOperatorState = new OperatorState<Map<Integer, KafkaOffset>>(partitionOffsets); kafkaOffSetOperatorState = new OperatorState<Map<Integer, KafkaOffset>>(partitionOffsets);
context.registerState("kafka", kafkaOffSetOperatorState);
} }
iterator = new KafkaMultiplePartitionsIterator(topicId, partitionOffsets, kafkaTopicUtils, this.consumerConfig); iterator = new KafkaMultiplePartitionsIterator(topicId, partitionOffsets, kafkaTopicUtils, this.consumerConfig);
...@@ -272,4 +271,14 @@ public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> { ...@@ -272,4 +271,14 @@ public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> {
Properties props = (Properties) in.readObject(); Properties props = (Properties) in.readObject();
consumerConfig = new ConsumerConfig(props); consumerConfig = new ConsumerConfig(props);
} }
@Override
public HashMap<Integer, KafkaOffset> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
return new HashMap<Integer, KafkaOffset>(this.partitionOffsets);
}
@Override
public void restoreState(HashMap<Integer, KafkaOffset> state) {
this.partitionOffsets = state;
}
} }
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
package org.apache.flink.streaming.api.checkpoint; package org.apache.flink.streaming.api.checkpoint;
import org.apache.flink.runtime.state.OperatorState; import java.io.Serializable;
/** /**
* This method must be implemented by functions that have state that needs to be * This method must be implemented by functions that have state that needs to be
...@@ -31,12 +31,14 @@ import org.apache.flink.runtime.state.OperatorState; ...@@ -31,12 +31,14 @@ import org.apache.flink.runtime.state.OperatorState;
* continue to work and mutate the state, even while the state snapshot is being accessed, * continue to work and mutate the state, even while the state snapshot is being accessed,
* can implement the {@link org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously} * can implement the {@link org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously}
* interface.</p> * interface.</p>
*
* @param <T> The type of the operator state.
*/ */
public interface Checkpointed { public interface Checkpointed<T extends Serializable> {
/** /**
* Gets the current operator state as a checkpoint. The state must reflect all operations * Gets the current state of the function of operator. The state must reflect the result of all
* from all prior operations if this function. * prior invocations to this function.
* *
* @param checkpointId The ID of the checkpoint. * @param checkpointId The ID of the checkpoint.
* @param checkpointTimestamp The timestamp of the checkpoint, as derived by * @param checkpointTimestamp The timestamp of the checkpoint, as derived by
...@@ -49,5 +51,13 @@ public interface Checkpointed { ...@@ -49,5 +51,13 @@ public interface Checkpointed {
* recovery), or to discard this checkpoint attempt and to continue running * recovery), or to discard this checkpoint attempt and to continue running
* and to try again with the next checkpoint attempt. * and to try again with the next checkpoint attempt.
*/ */
OperatorState<?> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception; T snapshotState(long checkpointId, long checkpointTimestamp) throws Exception;
/**
* Restores the state of the function or operator to that of a previous checkpoint.
* This method is invoked when a function is executed as part of a recovery run.
* *
* @param state The state to be restored.
*/
void restoreState(T state);
} }
...@@ -26,6 +26,8 @@ import org.apache.flink.api.common.functions.RuntimeContext; ...@@ -26,6 +26,8 @@ import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.CheckpointCommitter;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.runtime.io.IndexedReaderIterator; import org.apache.flink.streaming.runtime.io.IndexedReaderIterator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecordSerializer; import org.apache.flink.streaming.runtime.streamrecord.StreamRecordSerializer;
...@@ -107,16 +109,14 @@ public abstract class StreamOperator<IN, OUT> implements Serializable { ...@@ -107,16 +109,14 @@ public abstract class StreamOperator<IN, OUT> implements Serializable {
return nextRecord; return nextRecord;
} catch (IOException e) { } catch (IOException e) {
if (isRunning) { if (isRunning) {
throw new RuntimeException("Could not read next record due to: " throw new RuntimeException("Could not read next record", e);
+ StringUtils.stringifyException(e));
} else { } else {
// Task already cancelled do nothing // Task already cancelled do nothing
return null; return null;
} }
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
if (isRunning) { if (isRunning) {
throw new RuntimeException("Could not read next record due to: " throw new RuntimeException("Could not read next record", e);
+ StringUtils.stringifyException(e));
} else { } else {
// Task already cancelled do nothing // Task already cancelled do nothing
return null; return null;
...@@ -215,4 +215,49 @@ public abstract class StreamOperator<IN, OUT> implements Serializable { ...@@ -215,4 +215,49 @@ public abstract class StreamOperator<IN, OUT> implements Serializable {
public Function getUserFunction() { public Function getUserFunction() {
return userFunction; return userFunction;
} }
// ------------------------------------------------------------------------
// Checkpoints and Checkpoint Confirmations
// ------------------------------------------------------------------------
// NOTE - ALL OF THIS CODE WORKS ONLY FOR THE FIRST OPERATOR IN THE CHAIN
// IT NEEDS TO BE EXTENDED TO SUPPORT CHAINS
public void restoreInitialState(Serializable state) throws Exception {
if (userFunction instanceof Checkpointed) {
setStateOnFunction(state, userFunction);
}
else {
throw new IllegalStateException("Trying to restore state of a non-checkpointed function");
}
}
public Serializable getStateSnapshotFromFunction(long checkpointId, long timestamp) throws Exception {
if (userFunction instanceof Checkpointed) {
return ((Checkpointed<?>) userFunction).snapshotState(checkpointId, timestamp);
}
else {
return null;
}
}
public void confirmCheckpointCompleted(long checkpointId, long timestamp) throws Exception {
if (userFunction instanceof CheckpointCommitter) {
try {
((CheckpointCommitter) userFunction).commitCheckpoint(checkpointId);
}
catch (Exception e) {
throw new Exception("Error while confirming checkpoint " + checkpointId + " to the stream function", e);
}
}
}
private static <T extends Serializable> void setStateOnFunction(Serializable state, Function function) {
@SuppressWarnings("unchecked")
T typedState = (T) state;
@SuppressWarnings("unchecked")
Checkpointed<T> typedFunction = (Checkpointed<T>) function;
typedFunction.restoreState(typedState);
}
} }
...@@ -88,8 +88,8 @@ public class OutputHandler<OUT> { ...@@ -88,8 +88,8 @@ public class OutputHandler<OUT> {
this.outerCollector = createChainedCollector(configuration); this.outerCollector = createChainedCollector(configuration);
} }
public void broadcastBarrier(long id) throws IOException, InterruptedException { public void broadcastBarrier(long id, long timestamp) throws IOException, InterruptedException {
StreamingSuperstep barrier = new StreamingSuperstep(id); StreamingSuperstep barrier = new StreamingSuperstep(id, timestamp);
for (StreamOutput<?> streamOutput : outputMap.values()) { for (StreamOutput<?> streamOutput : outputMap.values()) {
streamOutput.broadcastEvent(barrier); streamOutput.broadcastEvent(barrier);
} }
......
...@@ -18,18 +18,15 @@ ...@@ -18,18 +18,15 @@
package org.apache.flink.streaming.runtime.tasks; package org.apache.flink.streaming.runtime.tasks;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap; import java.io.Serializable;
import java.util.Map;
import org.apache.flink.runtime.event.task.TaskEvent; import org.apache.flink.runtime.event.task.TaskEvent;
import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.jobgraph.tasks.BarrierTransceiver;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointCommittingOperator; import org.apache.flink.runtime.jobgraph.tasks.CheckpointCommittingOperator;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointedOperator; import org.apache.flink.runtime.jobgraph.tasks.CheckpointedOperator;
import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier; import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier;
import org.apache.flink.runtime.state.LocalStateHandle; import org.apache.flink.runtime.state.LocalStateHandle;
import org.apache.flink.runtime.state.OperatorState;
import org.apache.flink.runtime.util.event.EventListener; import org.apache.flink.runtime.util.event.EventListener;
import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.ChainableStreamOperator; import org.apache.flink.streaming.api.operators.ChainableStreamOperator;
...@@ -40,16 +37,18 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecordSerializer; ...@@ -40,16 +37,18 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecordSerializer;
import org.apache.flink.util.Collector; import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator; import org.apache.flink.util.MutableObjectIterator;
import org.apache.flink.util.StringUtils; import org.apache.flink.util.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTaskContext<OUT>, public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTaskContext<OUT>,
OperatorStateCarrier<LocalStateHandle>, CheckpointedOperator, CheckpointCommittingOperator, OperatorStateCarrier<LocalStateHandle>, CheckpointedOperator, CheckpointCommittingOperator {
BarrierTransceiver {
private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class); private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
private final Object checkpointLock = new Object();
private static int numTasks; private static int numTasks;
protected StreamConfig configuration; protected StreamConfig configuration;
...@@ -62,7 +61,6 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask ...@@ -62,7 +61,6 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask
protected volatile boolean isRunning = false; protected volatile boolean isRunning = false;
private StreamingRuntimeContext context; private StreamingRuntimeContext context;
private Map<String, OperatorState<?>> states;
protected ClassLoader userClassLoader; protected ClassLoader userClassLoader;
...@@ -90,32 +88,7 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask ...@@ -90,32 +88,7 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask
protected void initialize() { protected void initialize() {
this.userClassLoader = getUserCodeClassLoader(); this.userClassLoader = getUserCodeClassLoader();
this.configuration = new StreamConfig(getTaskConfiguration()); this.configuration = new StreamConfig(getTaskConfiguration());
this.states = new HashMap<String, OperatorState<?>>(); this.context = createRuntimeContext(getEnvironment().getTaskName());
this.context = createRuntimeContext(getEnvironment().getTaskName(), this.states);
}
@Override
public void broadcastBarrierFromSource(long id) {
// Only called at input vertices
if (LOG.isDebugEnabled()) {
LOG.debug("Received barrier from jobmanager: " + id);
}
actOnBarrier(id);
}
/**
* This method is called to confirm that a barrier has been fully processed.
* It sends an acknowledgment to the jobmanager. In the current version if
* there is user state it also checkpoints the state to the jobmanager.
*/
@Override
public void confirmBarrier(long barrierID) throws IOException {
if (configuration.getStateMonitoring() && !states.isEmpty()) {
getEnvironment().acknowledgeCheckpoint(barrierID, new LocalStateHandle(states));
}
else {
getEnvironment().acknowledgeCheckpoint(barrierID);
}
} }
public void setInputsOutputs() { public void setInputsOutputs() {
...@@ -136,11 +109,10 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask ...@@ -136,11 +109,10 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask
return instanceID; return instanceID;
} }
public StreamingRuntimeContext createRuntimeContext(String taskName, public StreamingRuntimeContext createRuntimeContext(String taskName) {
Map<String, OperatorState<?>> states) {
Environment env = getEnvironment(); Environment env = getEnvironment();
return new StreamingRuntimeContext(taskName, env, getUserCodeClassLoader(), return new StreamingRuntimeContext(taskName, env, getUserCodeClassLoader(),
getExecutionConfig(), states); getExecutionConfig());
} }
@Override @Override
...@@ -272,62 +244,98 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask ...@@ -272,62 +244,98 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask
return this.superstepListener; return this.superstepListener;
} }
// ------------------------------------------------------------------------
// Checkpoint and Restore
// ------------------------------------------------------------------------
/** /**
* Method to be called when a barrier is received from all the input * Re-injects the user states into the map. Also set the state on the functions.
* channels. It should broadcast the barrier to the output operators, */
* checkpoint the state and send an ack. @Override
public void setInitialState(LocalStateHandle stateHandle) throws Exception {
// here, we later resolve the state handle into the actual state by
// loading the state described by the handle from the backup store
Serializable state = stateHandle.getState();
streamOperator.restoreInitialState(state);
}
/**
* This method is either called directly by the checkpoint coordinator, or called
* when all incoming channels have reported a barrier
* *
* @param id * @param checkpointId
* @param timestamp
* @throws Exception
*/ */
private synchronized void actOnBarrier(long id) { @Override
public void triggerCheckpoint(long checkpointId, long timestamp) throws Exception {
synchronized (checkpointLock) {
if (isRunning) { if (isRunning) {
try { try {
outputHandler.broadcastBarrier(id); LOG.info("Starting checkpoint " + checkpointId);
confirmBarrier(id);
if (LOG.isDebugEnabled()) { // first draw the state that should go into checkpoint
LOG.debug("Superstep " + id + " processed: " + StreamTask.this); LocalStateHandle state;
try {
Serializable userState = streamOperator.getStateSnapshotFromFunction(checkpointId, timestamp);
state = userState == null ? null : new LocalStateHandle(userState);
} }
} catch (Exception e) { catch (Exception e) {
// Only throw any exception if the vertex is still running throw new Exception("Error while drawing snapshot of the user state.");
if (isRunning) {
throw new RuntimeException(e);
} }
// now emit the checkpoint barriers
outputHandler.broadcastBarrier(checkpointId, timestamp);
// now confirm the checkpoint
if (state == null) {
getEnvironment().acknowledgeCheckpoint(checkpointId);
} else {
getEnvironment().acknowledgeCheckpoint(checkpointId, state);
} }
} }
catch (Exception e) {
if (isRunning) {
throw e;
}
} }
@Override
public String toString() {
return getEnvironment().getTaskNameWithSubtasks();
} }
/**
* Re-injects the user states into the map
*/
@Override
public void setInitialState(LocalStateHandle stateHandle) {
this.states.putAll(stateHandle.getState());
} }
@Override
public void triggerCheckpoint(long checkpointId, long timestamp) {
broadcastBarrierFromSource(checkpointId);
} }
@Override @Override
public void confirmCheckpoint(long checkpointId, long timestamp) { public void confirmCheckpoint(long checkpointId, long timestamp) throws Exception {
// we do nothing here so far. this should call commit on the source function, for example // we do nothing here so far. this should call commit on the source function, for example
synchronized (checkpointLock) {
streamOperator.confirmCheckpointCompleted(checkpointId, timestamp);
}
} }
// ------------------------------------------------------------------------
// Utilities
// ------------------------------------------------------------------------
@Override
public String toString() {
return getEnvironment().getTaskNameWithSubtasks();
}
// ------------------------------------------------------------------------
private class SuperstepEventListener implements EventListener<TaskEvent> { private class SuperstepEventListener implements EventListener<TaskEvent> {
@Override @Override
public void onEvent(TaskEvent event) { public void onEvent(TaskEvent event) {
actOnBarrier(((StreamingSuperstep) event).getId()); try {
StreamingSuperstep sStep = (StreamingSuperstep) event;
triggerCheckpoint(sStep.getId(), sStep.getTimestamp());
}
catch (Exception e) {
throw new RuntimeException(
"Error triggering a checkpoint as the result of receiving checkpoint barrier", e);
}
} }
} }
} }
...@@ -18,17 +18,13 @@ ...@@ -18,17 +18,13 @@
package org.apache.flink.streaming.runtime.tasks; package org.apache.flink.streaming.runtime.tasks;
import java.util.Map;
import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.RichFunction;
import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext; import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
import org.apache.flink.runtime.operators.util.TaskConfig; import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.runtime.state.OperatorState;
/** /**
* Implementation of the {@link RuntimeContext}, created by runtime stream UDF * Implementation of the {@link RuntimeContext}, created by runtime stream UDF
...@@ -37,65 +33,13 @@ import org.apache.flink.runtime.state.OperatorState; ...@@ -37,65 +33,13 @@ import org.apache.flink.runtime.state.OperatorState;
public class StreamingRuntimeContext extends RuntimeUDFContext { public class StreamingRuntimeContext extends RuntimeUDFContext {
private final Environment env; private final Environment env;
private final Map<String, OperatorState<?>> operatorStates;
public StreamingRuntimeContext(String name, Environment env, ClassLoader userCodeClassLoader, public StreamingRuntimeContext(String name, Environment env, ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig, Map<String, OperatorState<?>> operatorStates) { ExecutionConfig executionConfig) {
super(name, env.getNumberOfSubtasks(), env.getIndexInSubtaskGroup(), userCodeClassLoader, super(name, env.getNumberOfSubtasks(), env.getIndexInSubtaskGroup(), userCodeClassLoader,
executionConfig, env.getDistributedCacheEntries()); executionConfig, env.getDistributedCacheEntries());
this.env = env; this.env = env;
this.operatorStates = operatorStates;
}
/**
 * Looks up the operator state that was registered under the given name.
 *
 * @param name
 *            Name of the operator state to be returned.
 * @return The operator state registered under {@code name}.
 * @throws RuntimeException
 *             if this context holds no state map at all, or if the map has
 *             no entry for {@code name}.
 */
public OperatorState<?> getState(String name) {
    // Without a backing map, nothing can have been registered.
    if (operatorStates == null) {
        throw new RuntimeException("No state has been registered for this operator.");
    }

    final OperatorState<?> registered = operatorStates.get(name);
    if (registered == null) {
        throw new RuntimeException("No state has been registered for the name: " + name);
    }
    return registered;
}
/**
 * Returns whether there is a state stored by the given name.
 *
 * @param name
 *            Name of the operator state to look up.
 * @return {@code true} if a state is registered under {@code name};
 *         {@code false} otherwise, including when this context holds no
 *         state map at all.
 */
public boolean containsState(String name) {
    // getState(String) treats a null map as "nothing registered"; answer
    // consistently here instead of failing with a NullPointerException.
    return operatorStates != null && operatorStates.containsKey(name);
}
/**
* This is a beta feature.<br><br>Register an operator state for this
* operator by the given name. This name can be used to retrieve the state
* during runtime using {@link StreamingRuntimeContext#getState(String)}. To
* obtain the {@link StreamingRuntimeContext} from the user-defined function
* use the {@link RichFunction#getRuntimeContext()} method.
*
* @param name
* The name of the operator state.
* @param state
* The state to be registered for this name.
*/
public void registerState(String name, OperatorState<?> state) {
if (state == null) {
throw new RuntimeException("Cannot register null state");
} else {
if (operatorStates.containsKey(name)) {
throw new RuntimeException("State is already registered");
} else {
operatorStates.put(name, state);
}
}
} }
/** /**
......
...@@ -27,34 +27,57 @@ import org.apache.flink.runtime.event.task.TaskEvent; ...@@ -27,34 +27,57 @@ import org.apache.flink.runtime.event.task.TaskEvent;
public class StreamingSuperstep extends TaskEvent { public class StreamingSuperstep extends TaskEvent {
protected long id; protected long id;
protected long timestamp;
public StreamingSuperstep() { public StreamingSuperstep() {}
public StreamingSuperstep(long id, long timestamp) {
this.id = id;
this.timestamp = timestamp;
} }
public StreamingSuperstep(long id) { public long getId() {
this.id = id; return id;
} }
/**
 * Returns the timestamp carried by this superstep event.
 *
 * @return the value of the {@code timestamp} field, as assigned by the
 *         two-argument constructor or by {@code read(DataInputView)}.
 */
public long getTimestamp() {
    // Must be the timestamp field, not id: getId() already exposes the id,
    // and callers pass both values on separately.
    return timestamp;
}
// ------------------------------------------------------------------------
@Override @Override
public void write(DataOutputView out) throws IOException { public void write(DataOutputView out) throws IOException {
out.writeLong(id); out.writeLong(id);
out.writeLong(timestamp);
} }
@Override @Override
public void read(DataInputView in) throws IOException { public void read(DataInputView in) throws IOException {
id = in.readLong(); id = in.readLong();
timestamp = in.readLong();
} }
public long getId() { // ------------------------------------------------------------------------
return id;
@Override
public int hashCode() {
return (int) (id ^ (id >>> 32) ^ timestamp ^(timestamp >>> 32));
} }
@Override
public boolean equals(Object other) { public boolean equals(Object other) {
if (other == null || !(other instanceof StreamingSuperstep)) { if (other == null || !(other instanceof StreamingSuperstep)) {
return false; return false;
} else {
return ((StreamingSuperstep) other).id == this.id;
} }
else {
StreamingSuperstep that = (StreamingSuperstep) other;
return that.id == this.id && that.timestamp == this.timestamp;
}
}
@Override
public String toString() {
return String.format("StreamingSuperstep %d @ %d", id, timestamp);
} }
} }
...@@ -33,8 +33,8 @@ import org.apache.flink.runtime.io.network.buffer.BufferRecycler; ...@@ -33,8 +33,8 @@ import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.util.event.EventListener; import org.apache.flink.runtime.util.event.EventListener;
import org.apache.flink.streaming.runtime.io.BarrierBuffer;
import org.apache.flink.streaming.runtime.tasks.StreamingSuperstep; import org.apache.flink.streaming.runtime.tasks.StreamingSuperstep;
import org.junit.Test; import org.junit.Test;
public class BarrierBufferTest { public class BarrierBufferTest {
...@@ -201,7 +201,7 @@ public class BarrierBufferTest { ...@@ -201,7 +201,7 @@ public class BarrierBufferTest {
} }
protected static BufferOrEvent createSuperstep(long id, int channel) { protected static BufferOrEvent createSuperstep(long id, int channel) {
return new BufferOrEvent(new StreamingSuperstep(id), channel); return new BufferOrEvent(new StreamingSuperstep(id, System.currentTimeMillis()), channel);
} }
protected static BufferOrEvent createBuffer(int channel) { protected static BufferOrEvent createBuffer(int channel) {
......
...@@ -22,7 +22,7 @@ import org.apache.commons.io.FileUtils; ...@@ -22,7 +22,7 @@ import org.apache.commons.io.FileUtils;
import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.OperatorState; import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
...@@ -145,13 +145,16 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur ...@@ -145,13 +145,16 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur
} }
} }
public static class SleepyDurableGenerateSequence extends RichParallelSourceFunction<Long> { public static class SleepyDurableGenerateSequence extends RichParallelSourceFunction<Long>
implements Checkpointed<Long> {
private static final long SLEEP_TIME = 50; private static final long SLEEP_TIME = 50;
private final File coordinateDir; private final File coordinateDir;
private final long end; private final long end;
private long collected;
public SleepyDurableGenerateSequence(File coordinateDir, long end) { public SleepyDurableGenerateSequence(File coordinateDir, long end) {
this.coordinateDir = coordinateDir; this.coordinateDir = coordinateDir;
this.end = end; this.end = end;
...@@ -162,23 +165,10 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur ...@@ -162,23 +165,10 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur
public void run(Collector<Long> collector) throws Exception { public void run(Collector<Long> collector) throws Exception {
StreamingRuntimeContext context = (StreamingRuntimeContext) getRuntimeContext(); StreamingRuntimeContext context = (StreamingRuntimeContext) getRuntimeContext();
OperatorState<Long> collectedState;
if (context.containsState("collected")) {
collectedState = (OperatorState<Long>) context.getState("collected");
// if (collected == 0) {
// throw new RuntimeException("The state did not capture a completed checkpoint");
// }
}
else {
collectedState = new OperatorState<Long>(0L);
context.registerState("collected", collectedState);
}
final long stepSize = context.getNumberOfParallelSubtasks(); final long stepSize = context.getNumberOfParallelSubtasks();
final long congruence = context.getIndexOfThisSubtask(); final long congruence = context.getIndexOfThisSubtask();
final long toCollect = (end % stepSize > congruence) ? (end / stepSize + 1) : (end / stepSize); final long toCollect = (end % stepSize > congruence) ? (end / stepSize + 1) : (end / stepSize);
long collected = collectedState.getState();
final File proceedFile = new File(coordinateDir, PROCEED_MARKER_FILE); final File proceedFile = new File(coordinateDir, PROCEED_MARKER_FILE);
boolean checkForProceedFile = true; boolean checkForProceedFile = true;
...@@ -196,13 +186,22 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur ...@@ -196,13 +186,22 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur
} }
collector.collect(collected * stepSize + congruence); collector.collect(collected * stepSize + congruence);
collectedState.update(collected);
collected++; collected++;
} }
} }
@Override @Override
public void cancel() {} public void cancel() {}
/**
 * Hands out the current value of the {@code collected} counter as the
 * state of this source for a checkpoint.
 *
 * @param checkpointId
 *            Id of the checkpoint being taken; not used by this implementation.
 * @param checkpointTimestamp
 *            Timestamp of the checkpoint being taken; not used by this implementation.
 * @return The boxed value of {@code collected} at the time of the call.
 */
@Override
public Long snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
    final Long snapshot = Long.valueOf(collected);
    return snapshot;
}
/**
 * Overwrites the {@code collected} counter with a previously snapshotted value.
 *
 * @param state
 *            The counter value to restore. It is unboxed directly, so a
 *            {@code null} argument results in a {@link NullPointerException}.
 */
@Override
public void restoreState(Long state) {
    this.collected = state.longValue();
}
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册