提交 3b2ee23b 编写于 作者: S Stephan Ewen

[streaming] Integrate new checkpointed interface with StreamTask,...

[streaming] Integrate new checkpointed interface with StreamTask, StreamOperator, and PersistentKafkaSource
上级 ededb6b7
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.jobgraph.tasks;
import java.io.IOException;
/**
* A BarrierTransceiver describes an operator's barrier checkpointing behavior used for
* fault tolerance. In the most common case [[broadcastBarrier]] is being expected to be called
* periodically upon receiving a checkpoint barrier. Furthermore, a [[confirmBarrier]] method should
* be implemented and used for acknowledging a specific checkpoint checkpoint.
*/
public interface BarrierTransceiver {
/**
* A callback for notifying an operator of a new checkpoint barrier.
* @param barrierID
*/
public void broadcastBarrierFromSource(long barrierID);
/**
* A callback for confirming that a barrier checkpoint is complete
* @param barrierID
*/
public void confirmBarrier(long barrierID) throws IOException;
}
......@@ -20,5 +20,5 @@ package org.apache.flink.runtime.jobgraph.tasks;
public interface CheckpointCommittingOperator {
void confirmCheckpoint(long checkpointId, long timestamp);
void confirmCheckpoint(long checkpointId, long timestamp) throws Exception;
}
......@@ -20,5 +20,5 @@ package org.apache.flink.runtime.jobgraph.tasks;
public interface CheckpointedOperator {
void triggerCheckpoint(long checkpointId, long timestamp);
void triggerCheckpoint(long checkpointId, long timestamp) throws Exception;
}
......@@ -33,6 +33,6 @@ public interface OperatorStateCarrier<T extends StateHandle<?>> {
*
* @param stateHandle The handle to the state.
*/
public void setInitialState(T stateHandle);
public void setInitialState(T stateHandle) throws Exception;
}
......@@ -18,23 +18,23 @@
package org.apache.flink.runtime.state;
import java.util.Map;
import java.io.Serializable;
/**
* A StateHandle that includes a map of operator states directly.
*/
public class LocalStateHandle implements StateHandle<Map<String, OperatorState<?>>> {
public class LocalStateHandle implements StateHandle<Serializable> {
private static final long serialVersionUID = 2093619217898039610L;
private final Map<String, OperatorState<?>> stateMap;
private final Serializable state;
public LocalStateHandle(Map<String,OperatorState<?>> state) {
this.stateMap = state;
public LocalStateHandle(Serializable state) {
this.state = state;
}
@Override
public Map<String,OperatorState<?>> getState() {
return stateMap;
public Serializable getState() {
return state;
}
}
......@@ -37,7 +37,9 @@ public class StateUtils {
* @param state The state handle.
* @param <T> Type bound for the
*/
public static <T extends StateHandle<?>> void setOperatorState(OperatorStateCarrier<?> op, StateHandle<?> state) {
public static <T extends StateHandle<?>> void setOperatorState(OperatorStateCarrier<?> op, StateHandle<?> state)
throws Exception
{
@SuppressWarnings("unchecked")
OperatorStateCarrier<T> typedOp = (OperatorStateCarrier<T>) op;
@SuppressWarnings("unchecked")
......
......@@ -45,7 +45,6 @@ import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.jobgraph.tasks.BarrierTransceiver;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointCommittingOperator;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointedOperator;
import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier;
......@@ -57,6 +56,7 @@ import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.runtime.state.StateUtils;
import org.apache.flink.runtime.util.SerializedValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -872,23 +872,6 @@ public class Task implements Runnable {
};
executeAsyncCallRunnable(runnable, "Checkpoint Trigger");
}
else if (invokable instanceof BarrierTransceiver) {
final BarrierTransceiver barrierTransceiver = (BarrierTransceiver) invokable;
final Logger logger = LOG;
Runnable runnable = new Runnable() {
@Override
public void run() {
try {
barrierTransceiver.broadcastBarrierFromSource(checkpointID);
}
catch (Throwable t) {
logger.error("Error while triggering checkpoint barriers", t);
}
}
};
executeAsyncCallRunnable(runnable, "Checkpoint Trigger");
}
else {
LOG.error("Task received a checkpoint request, but is not a checkpointing task - "
+ taskNameWithSubtask);
......
......@@ -27,15 +27,14 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
import org.apache.flink.runtime.state.LocalStateHandle;
import org.apache.flink.runtime.state.OperatorState;
import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.runtime.util.SerializableObject;
import org.apache.flink.runtime.util.SerializedValue;
import org.junit.Test;
import org.mockito.Mockito;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
......@@ -52,7 +51,7 @@ public class CheckpointStateRestoreTest {
public void testSetState() {
try {
final SerializedValue<StateHandle<?>> serializedState = new SerializedValue<StateHandle<?>>(
new LocalStateHandle(Collections.<String,OperatorState<?>>emptyMap()));
new LocalStateHandle(new SerializableObject()));
final JobID jid = new JobID();
final JobVertexID statefulId = new JobVertexID();
......@@ -121,7 +120,7 @@ public class CheckpointStateRestoreTest {
public void testStateOnlyPartiallyAvailable() {
try {
final SerializedValue<StateHandle<?>> serializedState = new SerializedValue<StateHandle<?>>(
new LocalStateHandle(Collections.<String,OperatorState<?>>emptyMap()));
new LocalStateHandle(new SerializableObject()));
final JobID jid = new JobID();
final JobVertexID statefulId = new JobVertexID();
......
......@@ -28,6 +28,7 @@ import com.google.common.base.Preconditions;
import kafka.consumer.ConsumerConfig;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.OperatorState;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;
import org.apache.flink.streaming.connectors.ConnectorSource;
import org.apache.flink.streaming.connectors.kafka.api.simple.iterator.KafkaConsumerIterator;
......@@ -51,7 +52,7 @@ import org.slf4j.LoggerFactory;
* @param <OUT>
* Type of the messages on the topic.
*/
public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> {
public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> implements Checkpointed<HashMap<Integer, KafkaOffset>> {
private static final long serialVersionUID = 1L;
private static final Logger LOG = LoggerFactory.getLogger(PersistentKafkaSource.class);
......@@ -202,13 +203,13 @@ public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> {
if (indexOfSubtask >= numberOfPartitions) {
LOG.info("Creating idle consumer because this subtask ({}) is higher than the number partitions ({})", indexOfSubtask + 1, numberOfPartitions);
iterator = new KafkaIdleConsumerIterator();
} else {
if (context.containsState("kafka")) {
}
else {
if (partitionOffsets != null) {
// we have restored state
LOG.info("Initializing PersistentKafkaSource from existing state.");
kafkaOffSetOperatorState = (OperatorState<Map<Integer, KafkaOffset>>) context.getState("kafka");
partitionOffsets = kafkaOffSetOperatorState.getState();
} else {
}
else {
LOG.info("No existing state found. Creating new");
partitionOffsets = new HashMap<Integer, KafkaOffset>();
......@@ -217,8 +218,6 @@ public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> {
}
kafkaOffSetOperatorState = new OperatorState<Map<Integer, KafkaOffset>>(partitionOffsets);
context.registerState("kafka", kafkaOffSetOperatorState);
}
iterator = new KafkaMultiplePartitionsIterator(topicId, partitionOffsets, kafkaTopicUtils, this.consumerConfig);
......@@ -272,4 +271,14 @@ public class PersistentKafkaSource<OUT> extends ConnectorSource<OUT> {
Properties props = (Properties) in.readObject();
consumerConfig = new ConsumerConfig(props);
}
@Override
public HashMap<Integer, KafkaOffset> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
return new HashMap<Integer, KafkaOffset>(this.partitionOffsets);
}
@Override
public void restoreState(HashMap<Integer, KafkaOffset> state) {
this.partitionOffsets = state;
}
}
......@@ -18,7 +18,7 @@
package org.apache.flink.streaming.api.checkpoint;
import org.apache.flink.runtime.state.OperatorState;
import java.io.Serializable;
/**
* This method must be implemented by functions that have state that needs to be
......@@ -31,12 +31,14 @@ import org.apache.flink.runtime.state.OperatorState;
* continue to work and mutate the state, even while the state snapshot is being accessed,
* can implement the {@link org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously}
* interface.</p>
*
* @param <T> The type of the operator state.
*/
public interface Checkpointed {
public interface Checkpointed<T extends Serializable> {
/**
* Gets the current operator state as a checkpoint. The state must reflect all operations
* from all prior operations if this function.
* Gets the current state of the function of operator. The state must reflect the result of all
* prior invocations to this function.
*
* @param checkpointId The ID of the checkpoint.
* @param checkpointTimestamp The timestamp of the checkpoint, as derived by
......@@ -49,5 +51,13 @@ public interface Checkpointed {
* recovery), or to discard this checkpoint attempt and to continue running
* and to try again with the next checkpoint attempt.
*/
OperatorState<?> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception;
T snapshotState(long checkpointId, long checkpointTimestamp) throws Exception;
/**
* Restores the state of the function or operator to that of a previous checkpoint.
* This method is invoked when a function is executed as part of a recovery run.
* *
* @param state The state to be restored.
*/
void restoreState(T state);
}
......@@ -26,6 +26,8 @@ import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.CheckpointCommitter;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.runtime.io.IndexedReaderIterator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecordSerializer;
......@@ -107,16 +109,14 @@ public abstract class StreamOperator<IN, OUT> implements Serializable {
return nextRecord;
} catch (IOException e) {
if (isRunning) {
throw new RuntimeException("Could not read next record due to: "
+ StringUtils.stringifyException(e));
throw new RuntimeException("Could not read next record", e);
} else {
// Task already cancelled do nothing
return null;
}
} catch (IllegalStateException e) {
if (isRunning) {
throw new RuntimeException("Could not read next record due to: "
+ StringUtils.stringifyException(e));
throw new RuntimeException("Could not read next record", e);
} else {
// Task already cancelled do nothing
return null;
......@@ -215,4 +215,49 @@ public abstract class StreamOperator<IN, OUT> implements Serializable {
public Function getUserFunction() {
return userFunction;
}
// ------------------------------------------------------------------------
// Checkpoints and Checkpoint Confirmations
// ------------------------------------------------------------------------
// NOTE - ALL OF THIS CODE WORKS ONLY FOR THE FIRST OPERATOR IN THE CHAIN
// IT NEEDS TO BE EXTENDED TO SUPPORT CHAINS
public void restoreInitialState(Serializable state) throws Exception {
if (userFunction instanceof Checkpointed) {
setStateOnFunction(state, userFunction);
}
else {
throw new IllegalStateException("Trying to restore state of a non-checkpointed function");
}
}
public Serializable getStateSnapshotFromFunction(long checkpointId, long timestamp) throws Exception {
if (userFunction instanceof Checkpointed) {
return ((Checkpointed<?>) userFunction).snapshotState(checkpointId, timestamp);
}
else {
return null;
}
}
public void confirmCheckpointCompleted(long checkpointId, long timestamp) throws Exception {
if (userFunction instanceof CheckpointCommitter) {
try {
((CheckpointCommitter) userFunction).commitCheckpoint(checkpointId);
}
catch (Exception e) {
throw new Exception("Error while confirming checkpoint " + checkpointId + " to the stream function", e);
}
}
}
private static <T extends Serializable> void setStateOnFunction(Serializable state, Function function) {
@SuppressWarnings("unchecked")
T typedState = (T) state;
@SuppressWarnings("unchecked")
Checkpointed<T> typedFunction = (Checkpointed<T>) function;
typedFunction.restoreState(typedState);
}
}
......@@ -88,8 +88,8 @@ public class OutputHandler<OUT> {
this.outerCollector = createChainedCollector(configuration);
}
public void broadcastBarrier(long id) throws IOException, InterruptedException {
StreamingSuperstep barrier = new StreamingSuperstep(id);
public void broadcastBarrier(long id, long timestamp) throws IOException, InterruptedException {
StreamingSuperstep barrier = new StreamingSuperstep(id, timestamp);
for (StreamOutput<?> streamOutput : outputMap.values()) {
streamOutput.broadcastEvent(barrier);
}
......
......@@ -18,18 +18,15 @@
package org.apache.flink.streaming.runtime.tasks;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.io.Serializable;
import org.apache.flink.runtime.event.task.TaskEvent;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.jobgraph.tasks.BarrierTransceiver;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointCommittingOperator;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointedOperator;
import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier;
import org.apache.flink.runtime.state.LocalStateHandle;
import org.apache.flink.runtime.state.OperatorState;
import org.apache.flink.runtime.util.event.EventListener;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.ChainableStreamOperator;
......@@ -40,16 +37,18 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecordSerializer;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;
import org.apache.flink.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTaskContext<OUT>,
OperatorStateCarrier<LocalStateHandle>, CheckpointedOperator, CheckpointCommittingOperator,
BarrierTransceiver {
OperatorStateCarrier<LocalStateHandle>, CheckpointedOperator, CheckpointCommittingOperator {
private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
private final Object checkpointLock = new Object();
private static int numTasks;
protected StreamConfig configuration;
......@@ -62,7 +61,6 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask
protected volatile boolean isRunning = false;
private StreamingRuntimeContext context;
private Map<String, OperatorState<?>> states;
protected ClassLoader userClassLoader;
......@@ -90,32 +88,7 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask
protected void initialize() {
this.userClassLoader = getUserCodeClassLoader();
this.configuration = new StreamConfig(getTaskConfiguration());
this.states = new HashMap<String, OperatorState<?>>();
this.context = createRuntimeContext(getEnvironment().getTaskName(), this.states);
}
@Override
public void broadcastBarrierFromSource(long id) {
// Only called at input vertices
if (LOG.isDebugEnabled()) {
LOG.debug("Received barrier from jobmanager: " + id);
}
actOnBarrier(id);
}
/**
* This method is called to confirm that a barrier has been fully processed.
* It sends an acknowledgment to the jobmanager. In the current version if
* there is user state it also checkpoints the state to the jobmanager.
*/
@Override
public void confirmBarrier(long barrierID) throws IOException {
if (configuration.getStateMonitoring() && !states.isEmpty()) {
getEnvironment().acknowledgeCheckpoint(barrierID, new LocalStateHandle(states));
}
else {
getEnvironment().acknowledgeCheckpoint(barrierID);
}
this.context = createRuntimeContext(getEnvironment().getTaskName());
}
public void setInputsOutputs() {
......@@ -136,11 +109,10 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask
return instanceID;
}
public StreamingRuntimeContext createRuntimeContext(String taskName,
Map<String, OperatorState<?>> states) {
public StreamingRuntimeContext createRuntimeContext(String taskName) {
Environment env = getEnvironment();
return new StreamingRuntimeContext(taskName, env, getUserCodeClassLoader(),
getExecutionConfig(), states);
getExecutionConfig());
}
@Override
......@@ -272,62 +244,98 @@ public class StreamTask<IN, OUT> extends AbstractInvokable implements StreamTask
return this.superstepListener;
}
// ------------------------------------------------------------------------
// Checkpoint and Restore
// ------------------------------------------------------------------------
/**
* Method to be called when a barrier is received from all the input
* channels. It should broadcast the barrier to the output operators,
* checkpoint the state and send an ack.
*
* @param id
* Re-injects the user states into the map. Also set the state on the functions.
*/
private synchronized void actOnBarrier(long id) {
if (isRunning) {
try {
outputHandler.broadcastBarrier(id);
confirmBarrier(id);
if (LOG.isDebugEnabled()) {
LOG.debug("Superstep " + id + " processed: " + StreamTask.this);
}
} catch (Exception e) {
// Only throw any exception if the vertex is still running
if (isRunning) {
throw new RuntimeException(e);
}
}
}
}
@Override
public String toString() {
return getEnvironment().getTaskNameWithSubtasks();
public void setInitialState(LocalStateHandle stateHandle) throws Exception {
// here, we later resolve the state handle into the actual state by
// loading the state described by the handle from the backup store
Serializable state = stateHandle.getState();
streamOperator.restoreInitialState(state);
}
/**
* Re-injects the user states into the map
* This method is either called directly by the checkpoint coordinator, or called
* when all incoming channels have reported a barrier
*
* @param checkpointId
* @param timestamp
* @throws Exception
*/
@Override
public void setInitialState(LocalStateHandle stateHandle) {
this.states.putAll(stateHandle.getState());
public void triggerCheckpoint(long checkpointId, long timestamp) throws Exception {
synchronized (checkpointLock) {
if (isRunning) {
try {
LOG.info("Starting checkpoint " + checkpointId);
// first draw the state that should go into checkpoint
LocalStateHandle state;
try {
Serializable userState = streamOperator.getStateSnapshotFromFunction(checkpointId, timestamp);
state = userState == null ? null : new LocalStateHandle(userState);
}
catch (Exception e) {
throw new Exception("Error while drawing snapshot of the user state.");
}
// now emit the checkpoint barriers
outputHandler.broadcastBarrier(checkpointId, timestamp);
// now confirm the checkpoint
if (state == null) {
getEnvironment().acknowledgeCheckpoint(checkpointId);
} else {
getEnvironment().acknowledgeCheckpoint(checkpointId, state);
}
}
catch (Exception e) {
if (isRunning) {
throw e;
}
}
}
}
}
@Override
public void triggerCheckpoint(long checkpointId, long timestamp) {
broadcastBarrierFromSource(checkpointId);
public void confirmCheckpoint(long checkpointId, long timestamp) throws Exception {
// we do nothing here so far. this should call commit on the source function, for example
synchronized (checkpointLock) {
streamOperator.confirmCheckpointCompleted(checkpointId, timestamp);
}
}
// ------------------------------------------------------------------------
// Utilities
// ------------------------------------------------------------------------
@Override
public void confirmCheckpoint(long checkpointId, long timestamp) {
// we do nothing here so far. this should call commit on the source function, for example
public String toString() {
return getEnvironment().getTaskNameWithSubtasks();
}
// ------------------------------------------------------------------------
private class SuperstepEventListener implements EventListener<TaskEvent> {
@Override
public void onEvent(TaskEvent event) {
actOnBarrier(((StreamingSuperstep) event).getId());
try {
StreamingSuperstep sStep = (StreamingSuperstep) event;
triggerCheckpoint(sStep.getId(), sStep.getTimestamp());
}
catch (Exception e) {
throw new RuntimeException(
"Error triggering a checkpoint as the result of receiving checkpoint barrier", e);
}
}
}
}
......@@ -18,17 +18,13 @@
package org.apache.flink.streaming.runtime.tasks;
import java.util.Map;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.RichFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.runtime.state.OperatorState;
/**
* Implementation of the {@link RuntimeContext}, created by runtime stream UDF
......@@ -37,65 +33,13 @@ import org.apache.flink.runtime.state.OperatorState;
public class StreamingRuntimeContext extends RuntimeUDFContext {
private final Environment env;
private final Map<String, OperatorState<?>> operatorStates;
public StreamingRuntimeContext(String name, Environment env, ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig, Map<String, OperatorState<?>> operatorStates) {
ExecutionConfig executionConfig) {
super(name, env.getNumberOfSubtasks(), env.getIndexInSubtaskGroup(), userCodeClassLoader,
executionConfig, env.getDistributedCacheEntries());
this.env = env;
this.operatorStates = operatorStates;
}
/**
* Returns the operator state registered by the given name for the operator.
*
* @param name
* Name of the operator state to be returned.
* @return The operator state.
*/
public OperatorState<?> getState(String name) {
if (operatorStates == null) {
throw new RuntimeException("No state has been registered for this operator.");
} else {
OperatorState<?> state = operatorStates.get(name);
if (state != null) {
return state;
} else {
throw new RuntimeException("No state has been registered for the name: " + name);
}
}
}
/**
* Returns whether there is a state stored by the given name
*/
public boolean containsState(String name) {
return operatorStates.containsKey(name);
}
/**
* This is a beta feature </br></br> Register an operator state for this
* operator by the given name. This name can be used to retrieve the state
* during runtime using {@link StreamingRuntimeContext#getState(String)}. To
* obtain the {@link StreamingRuntimeContext} from the user-defined function
* use the {@link RichFunction#getRuntimeContext()} method.
*
* @param name
* The name of the operator state.
* @param state
* The state to be registered for this name.
*/
public void registerState(String name, OperatorState<?> state) {
if (state == null) {
throw new RuntimeException("Cannot register null state");
} else {
if (operatorStates.containsKey(name)) {
throw new RuntimeException("State is already registered");
} else {
operatorStates.put(name, state);
}
}
}
/**
......
......@@ -27,34 +27,57 @@ import org.apache.flink.runtime.event.task.TaskEvent;
public class StreamingSuperstep extends TaskEvent {
protected long id;
protected long timestamp;
public StreamingSuperstep() {
public StreamingSuperstep() {}
public StreamingSuperstep(long id, long timestamp) {
this.id = id;
this.timestamp = timestamp;
}
public StreamingSuperstep(long id) {
this.id = id;
public long getId() {
return id;
}
public long getTimestamp() {
return id;
}
// ------------------------------------------------------------------------
@Override
public void write(DataOutputView out) throws IOException {
out.writeLong(id);
out.writeLong(timestamp);
}
@Override
public void read(DataInputView in) throws IOException {
id = in.readLong();
timestamp = in.readLong();
}
// ------------------------------------------------------------------------
public long getId() {
return id;
@Override
public int hashCode() {
return (int) (id ^ (id >>> 32) ^ timestamp ^(timestamp >>> 32));
}
@Override
public boolean equals(Object other) {
if (other == null || !(other instanceof StreamingSuperstep)) {
return false;
} else {
return ((StreamingSuperstep) other).id == this.id;
}
else {
StreamingSuperstep that = (StreamingSuperstep) other;
return that.id == this.id && that.timestamp == this.timestamp;
}
}
@Override
public String toString() {
return String.format("StreamingSuperstep %d @ %d", id, timestamp);
}
}
......@@ -33,8 +33,8 @@ import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.util.event.EventListener;
import org.apache.flink.streaming.runtime.io.BarrierBuffer;
import org.apache.flink.streaming.runtime.tasks.StreamingSuperstep;
import org.junit.Test;
public class BarrierBufferTest {
......@@ -201,7 +201,7 @@ public class BarrierBufferTest {
}
protected static BufferOrEvent createSuperstep(long id, int channel) {
return new BufferOrEvent(new StreamingSuperstep(id), channel);
return new BufferOrEvent(new StreamingSuperstep(id, System.currentTimeMillis()), channel);
}
protected static BufferOrEvent createBuffer(int channel) {
......
......@@ -22,7 +22,7 @@ import org.apache.commons.io.FileUtils;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.OperatorState;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
......@@ -145,12 +145,15 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur
}
}
public static class SleepyDurableGenerateSequence extends RichParallelSourceFunction<Long> {
public static class SleepyDurableGenerateSequence extends RichParallelSourceFunction<Long>
implements Checkpointed<Long> {
private static final long SLEEP_TIME = 50;
private final File coordinateDir;
private final long end;
private long collected;
public SleepyDurableGenerateSequence(File coordinateDir, long end) {
this.coordinateDir = coordinateDir;
......@@ -162,23 +165,10 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur
public void run(Collector<Long> collector) throws Exception {
StreamingRuntimeContext context = (StreamingRuntimeContext) getRuntimeContext();
OperatorState<Long> collectedState;
if (context.containsState("collected")) {
collectedState = (OperatorState<Long>) context.getState("collected");
// if (collected == 0) {
// throw new RuntimeException("The state did not capture a completed checkpoint");
// }
}
else {
collectedState = new OperatorState<Long>(0L);
context.registerState("collected", collectedState);
}
final long stepSize = context.getNumberOfParallelSubtasks();
final long congruence = context.getIndexOfThisSubtask();
final long toCollect = (end % stepSize > congruence) ? (end / stepSize + 1) : (end / stepSize);
long collected = collectedState.getState();
final File proceedFile = new File(coordinateDir, PROCEED_MARKER_FILE);
boolean checkForProceedFile = true;
......@@ -196,13 +186,22 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur
}
collector.collect(collected * stepSize + congruence);
collectedState.update(collected);
collected++;
}
}
@Override
public void cancel() {}
@Override
public Long snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
return collected;
}
@Override
public void restoreState(Long state) {
collected = state;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册