提交 bf8c8e54 编写于 作者: R Robert Metzger 提交者: Stephan Ewen

[FLINK-2543] [core] Fix user object deserialization for file-based state handles.

Send exceptions from JM --> JC in serialized form.
Exceptions send from the JobManager to the JobClient were relying on
Akka's JavaSerialization, which does not have access to the user code classloader.

This closes #1048
上级 554b77bc
......@@ -27,7 +27,6 @@ import org.apache.flink.api.common.PlanExecutor;
import org.apache.flink.api.common.Program;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.client.SerializedJobExecutionResult;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.optimizer.DataStatistics;
......@@ -176,8 +175,7 @@ public class LocalExecutor extends PlanExecutor {
JobGraph jobGraph = jgg.compileJobGraph(op);
boolean sysoutPrint = isPrintingStatusDuringExecution();
SerializedJobExecutionResult result = flink.submitJobAndWait(jobGraph,sysoutPrint);
return result.toJobExecutionResult(ClassLoader.getSystemClassLoader());
return flink.submitJobAndWait(jobGraph, sysoutPrint);
}
finally {
if (shutDownAtEnd) {
......
......@@ -53,7 +53,6 @@ import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.client.JobClient;
import org.apache.flink.runtime.client.JobExecutionException;
import org.apache.flink.runtime.client.SerializedJobExecutionResult;
import org.apache.flink.runtime.instance.ActorGateway;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobmanager.JobManager;
......@@ -425,15 +424,8 @@ public class Client {
try{
if (wait) {
SerializedJobExecutionResult result = JobClient.submitJobAndWait(actorSystem,
jobManagerGateway, jobGraph, timeout, printStatusDuringExecution);
try {
return result.toJobExecutionResult(this.userCodeClassLoader);
}
catch (Exception e) {
throw new ProgramInvocationException(
"Failed to deserialize the accumulator result after the job execution", e);
}
return JobClient.submitJobAndWait(actorSystem,
jobManagerGateway, jobGraph, timeout, printStatusDuringExecution, this.userCodeClassLoader);
}
else {
JobClient.submitJobDetached(jobManagerGateway, jobGraph, timeout);
......
......@@ -46,7 +46,7 @@ public class InstantiationUtil {
* user-code ClassLoader.
*
*/
private static class ClassLoaderObjectInputStream extends ObjectInputStream {
public static class ClassLoaderObjectInputStream extends ObjectInputStream {
private ClassLoader classLoader;
private static final HashMap<String, Class<?>> primitiveClasses
......
......@@ -23,10 +23,10 @@ import akka.actor.ActorSystem;
import akka.actor.Address;
import akka.actor.PoisonPill;
import akka.actor.Props;
import akka.actor.Status;
import akka.pattern.Patterns;
import akka.util.Timeout;
import com.google.common.base.Preconditions;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
......@@ -36,6 +36,7 @@ import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.messages.JobClientMessages;
import org.apache.flink.runtime.messages.JobManagerMessages;
import org.apache.flink.runtime.util.SerializedThrowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -64,8 +65,7 @@ public class JobClient {
public static ActorSystem startJobClientActorSystem(Configuration config)
throws IOException {
LOG.info("Starting JobClient actor system");
Option<Tuple2<String, Object>> remoting =
new Some<Tuple2<String, Object>>(new Tuple2<String, Object>("", 0));
Option<Tuple2<String, Object>> remoting = new Some<>(new Tuple2<String, Object>("", 0));
// start a remote actor system to listen on an arbitrary port
ActorSystem system = AkkaUtils.createActorSystem(config, remoting);
......@@ -123,12 +123,13 @@ public class JobClient {
* @throws org.apache.flink.runtime.client.JobExecutionException Thrown if the job
* execution fails.
*/
public static SerializedJobExecutionResult submitJobAndWait(
public static JobExecutionResult submitJobAndWait(
ActorSystem actorSystem,
ActorGateway jobManagerGateway,
JobGraph jobGraph,
FiniteDuration timeout,
boolean sysoutLogUpdates) throws JobExecutionException {
boolean sysoutLogUpdates,
ClassLoader userCodeClassloader) throws JobExecutionException {
Preconditions.checkNotNull(actorSystem, "The actorSystem must not be null.");
Preconditions.checkNotNull(jobManagerGateway, "The jobManagerGateway must not be null.");
......@@ -160,26 +161,30 @@ public class JobClient {
SerializedJobExecutionResult result = ((JobManagerMessages.JobResultSuccess) answer).result();
if (result != null) {
return result;
return result.toJobExecutionResult(userCodeClassloader);
} else {
throw new Exception("Job was successfully executed but result contained a null JobExecutionResult.");
}
} else if (answer instanceof Status.Failure) {
throw ((Status.Failure) answer).cause();
} else {
throw new Exception("Unknown answer after submitting the job: " + answer);
}
}
catch (JobExecutionException e) {
throw e;
if(e.getCause() instanceof SerializedThrowable) {
SerializedThrowable serializedThrowable = (SerializedThrowable)e.getCause();
Throwable deserialized = serializedThrowable.deserializeError(userCodeClassloader);
throw new JobExecutionException(jobGraph.getJobID(), "Job execution failed " + deserialized.getMessage(), deserialized);
} else {
throw e;
}
}
catch (TimeoutException e) {
throw new JobTimeoutException(jobGraph.getJobID(), "Timeout while waiting for JobManager answer. " +
"Job time exceeded " + AkkaUtils.INF_TIMEOUT(), e);
}
catch (Throwable t) {
catch (Throwable throwable) {
throw new JobExecutionException(jobGraph.getJobID(),
"Communication with JobManager failed: " + t.getMessage(), t);
"Communication with JobManager failed: " + throwable.getMessage(), throwable);
}
finally {
// failsafe shutdown of the client actor
......
......@@ -42,6 +42,7 @@ import org.apache.flink.runtime.jobmanager.scheduler.Scheduler;
import org.apache.flink.runtime.messages.ExecutionGraphMessages;
import org.apache.flink.runtime.taskmanager.TaskExecutionState;
import org.apache.flink.runtime.util.SerializableObject;
import org.apache.flink.runtime.util.SerializedThrowable;
import org.apache.flink.util.SerializedValue;
import org.apache.flink.util.ExceptionUtils;
......@@ -1028,8 +1029,12 @@ public class ExecutionGraph implements Serializable {
private void notifyJobStatusChange(JobStatus newState, Throwable error) {
if (jobStatusListenerActors.size() > 0) {
SerializedThrowable serializedThrowable = null;
if(error != null) {
serializedThrowable = new SerializedThrowable(error);
}
ExecutionGraphMessages.JobStatusChanged message =
new ExecutionGraphMessages.JobStatusChanged(jobID, newState, System.currentTimeMillis(), error);
new ExecutionGraphMessages.JobStatusChanged(jobID, newState, System.currentTimeMillis(), serializedThrowable);
for (ActorGateway listener: jobStatusListenerActors) {
listener.tell(message);
......
......@@ -33,6 +33,6 @@ public interface OperatorStateCarrier<T extends StateHandle<?>> {
*
* @param stateHandle The handle to the state.
*/
public void setInitialState(T stateHandle) throws Exception;
void setInitialState(T stateHandle) throws Exception;
}
......@@ -18,6 +18,8 @@
package org.apache.flink.runtime.state;
import org.apache.flink.util.InstantiationUtil;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
......@@ -56,9 +58,9 @@ public abstract class ByteStreamStateHandle implements StateHandle<Serializable>
protected abstract InputStream getInputStream() throws Exception;
@Override
public Serializable getState() throws Exception {
public Serializable getState(ClassLoader userCodeClassLoader) throws Exception {
if (!stateFetched()) {
ObjectInputStream stream = new ObjectInputStream(getInputStream());
ObjectInputStream stream = new InstantiationUtil.ClassLoaderObjectInputStream(getInputStream(), userCodeClassLoader);
try {
state = (Serializable) stream.readObject();
} finally {
......
......@@ -34,7 +34,8 @@ public class LocalStateHandle<T extends Serializable> implements StateHandle<T>
}
@Override
public T getState() {
public T getState(ClassLoader userCodeClassLoader) {
// The object has been deserialized correctly before
return state;
}
......
......@@ -24,7 +24,8 @@ import java.util.Map;
/**
* Wrapper for storing the handles for each state in a partitioned form. It can
* be used to repartition the state before re-injecting to the tasks.
*
*
* TODO: This class needs testing!
*/
public class PartitionedStateHandle implements
StateHandle<Map<Serializable, StateHandle<Serializable>>> {
......@@ -38,7 +39,7 @@ public class PartitionedStateHandle implements
}
@Override
public Map<Serializable, StateHandle<Serializable>> getState() throws Exception {
public Map<Serializable, StateHandle<Serializable>> getState(ClassLoader userCodeClassLoader) throws Exception {
return handles;
}
......
......@@ -28,12 +28,14 @@ import java.io.Serializable;
public interface StateHandle<T> extends Serializable {
/**
* This retrieves and return the state represented by the handle.
*
* This retrieves and return the state represented by the handle.
*
* @param userCodeClassLoader Class loader for deserializing user code specific classes
*
* @return The state represented by the handle.
* @throws java.lang.Exception Thrown, if the state cannot be fetched.
*/
T getState() throws Exception;
T getState(ClassLoader userCodeClassLoader) throws Exception;
/**
* Discards the state referred to by this handle, to free up resources in
......
......@@ -18,14 +18,11 @@
package org.apache.flink.runtime.taskmanager;
import java.util.Arrays;
import org.apache.flink.runtime.accumulators.AccumulatorSnapshot;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.api.common.JobID;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.runtime.util.SerializedThrowable;
/**
* This class represents an update about a task's execution state.
......@@ -47,11 +44,7 @@ public class TaskExecutionState implements java.io.Serializable {
private final ExecutionState executionState;
private final byte[] serializedError;
// The exception must not be (de)serialized with the class, as its
// class may not be part of the system class loader.
private transient Throwable cachedError;
private final SerializedThrowable throwable;
/** Serialized flink and user-defined accumulators */
private final AccumulatorSnapshot accumulators;
......@@ -104,49 +97,19 @@ public class TaskExecutionState implements java.io.Serializable {
ExecutionState executionState, Throwable error,
AccumulatorSnapshot accumulators) {
if (jobID == null || executionId == null || executionState == null) {
if (jobID == null || executionId == null || executionState == null) {
throw new NullPointerException();
}
this.jobID = jobID;
this.executionId = executionId;
this.executionState = executionState;
this.cachedError = error;
this.accumulators = accumulators;
if (error != null) {
byte[] serializedError;
try {
serializedError = InstantiationUtil.serializeObject(error);
}
catch (Throwable t) {
// could not serialize exception. send the stringified version instead
try {
this.cachedError = new Exception(ExceptionUtils.stringifyException(error));
serializedError = InstantiationUtil.serializeObject(this.cachedError);
}
catch (Throwable tt) {
// seems like we cannot do much to report the actual exception
// report a placeholder instead
try {
this.cachedError = new Exception("Cause is a '" + error.getClass().getName()
+ "' (failed to serialize or stringify)");
serializedError = InstantiationUtil.serializeObject(this.cachedError);
}
catch (Throwable ttt) {
// this should never happen unless the JVM is fubar.
// we just report the state without the error
this.cachedError = null;
serializedError = null;
}
}
}
this.serializedError = serializedError;
}
else {
this.serializedError = null;
if(error != null) {
this.throwable = new SerializedThrowable(error);
} else {
this.throwable = null;
}
this.accumulators = accumulators;
}
// --------------------------------------------------------------------------------------------
......@@ -160,19 +123,11 @@ public class TaskExecutionState implements java.io.Serializable {
* job this update refers to.
*/
public Throwable getError(ClassLoader usercodeClassloader) {
if (this.serializedError == null) {
if (this.throwable == null) {
return null;
} else {
return throwable.deserializeError(usercodeClassloader);
}
if (this.cachedError == null) {
try {
cachedError = (Throwable) InstantiationUtil.deserializeObject(this.serializedError, usercodeClassloader);
}
catch (Exception e) {
throw new RuntimeException("Error while deserializing the attached exception", e);
}
}
return this.cachedError;
}
/**
......@@ -218,8 +173,8 @@ public class TaskExecutionState implements java.io.Serializable {
return other.jobID.equals(this.jobID) &&
other.executionId.equals(this.executionId) &&
other.executionState == this.executionState &&
(other.serializedError == null ? this.serializedError == null :
(this.serializedError != null && Arrays.equals(this.serializedError, other.serializedError)));
(other.throwable == null ? this.throwable == null :
(this.throwable != null && throwable.equals(other.throwable) ));
}
else {
return false;
......@@ -235,7 +190,6 @@ public class TaskExecutionState implements java.io.Serializable {
public String toString() {
return String.format("TaskState jobId=%s, executionId=%s, state=%s, error=%s",
jobID, executionId, executionState,
cachedError == null ? (serializedError == null ? "(null)" : "(serialized)")
: (cachedError.getClass().getName() + ": " + cachedError.getMessage()));
throwable == null ? "(null)" : throwable.toString());
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.util;
import com.google.common.base.Preconditions;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.InstantiationUtil;
import java.io.Serializable;
import java.util.Arrays;
/**
* Utility class for dealing with serialized Throwables.
* Needed to send around user-specific exception classes with Akka.
*/
public class SerializedThrowable extends Exception implements Serializable {
private static final long serialVersionUID = 7284183123441947635L;
private final byte[] serializedError;
// The exception must not be (de)serialized with the class, as its
// class may not be part of the system class loader.
private transient Throwable cachedError;
/**
* Create a new SerializedThrowable.
* @param error The exception to serialize.
*/
public SerializedThrowable(Throwable error) {
Preconditions.checkNotNull(error, "The exception to serialize has to be set");
this.cachedError = error;
byte[] serializedError;
try {
serializedError = InstantiationUtil.serializeObject(error);
}
catch (Throwable t) {
// could not serialize exception. send the stringified version instead
try {
this.cachedError = new Exception(ExceptionUtils.stringifyException(error));
serializedError = InstantiationUtil.serializeObject(this.cachedError);
}
catch (Throwable tt) {
// seems like we cannot do much to report the actual exception
// report a placeholder instead
try {
this.cachedError = new Exception("Cause is a '" + error.getClass().getName()
+ "' (failed to serialize or stringify)");
serializedError = InstantiationUtil.serializeObject(this.cachedError);
}
catch (Throwable ttt) {
// this should never happen unless the JVM is fubar.
// we just report the state without the error
this.cachedError = null;
serializedError = null;
}
}
}
this.serializedError = serializedError;
}
public Throwable deserializeError(ClassLoader userCodeClassloader) {
if (this.cachedError == null) {
try {
cachedError = (Throwable) InstantiationUtil.deserializeObject(this.serializedError, userCodeClassloader);
}
catch (Exception e) {
throw new RuntimeException("Error while deserializing the attached exception", e);
}
}
return this.cachedError;
}
@Override
public boolean equals(Object obj) {
if(obj instanceof SerializedThrowable) {
return Arrays.equals(this.serializedError, ((SerializedThrowable)obj).serializedError);
}
return false;
}
@Override
public String toString() {
if(cachedError != null) {
return cachedError.getClass().getName() + ": " + cachedError.getMessage();
}
if(serializedError == null) {
return "(null)"; // can not happen as per Ctor check.
} else {
return "(serialized)";
}
}
public static Throwable get(Throwable serThrowable, ClassLoader loader) {
if(serThrowable instanceof SerializedThrowable) {
return ((SerializedThrowable)serThrowable).deserializeError(loader);
} else {
return serThrowable;
}
}
}
......@@ -47,11 +47,10 @@ import org.apache.flink.runtime.process.ProcessReaper
import org.apache.flink.runtime.security.SecurityUtils
import org.apache.flink.runtime.security.SecurityUtils.FlinkSecuredRunner
import org.apache.flink.runtime.taskmanager.TaskManager
import org.apache.flink.runtime.util.ZooKeeperUtil
import org.apache.flink.runtime.util.EnvironmentInformation
import org.apache.flink.runtime.util.{SerializedThrowable, ZooKeeperUtil, EnvironmentInformation}
import org.apache.flink.runtime.webmonitor.WebMonitor
import org.apache.flink.runtime.{FlinkActor, StreamingMode, LeaderSessionMessages}
import org.apache.flink.runtime.{LogMessages}
import org.apache.flink.runtime.LogMessages
import org.apache.flink.runtime.akka.AkkaUtils
import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager
import org.apache.flink.runtime.instance.{ActorGateway, AkkaActorGateway, InstanceManager}
......@@ -327,8 +326,12 @@ class JobManager(
currentJobs.get(jobID) match {
case Some((executionGraph, jobInfo)) => executionGraph.getJobName
log.info(s"Status of job $jobID (${executionGraph.getJobName}) changed to $newJobStatus.",
error)
val deserializedError = if(error != null) {
error.deserializeError(executionGraph.getUserClassLoader)
} else null
log.info(
s"Status of job $jobID (${executionGraph.getJobName}) changed to $newJobStatus.",
deserializedError)
if (newJobStatus.isTerminalState) {
jobInfo.end = timeStamp
......@@ -343,8 +346,10 @@ class JobManager(
log.error(s"Cannot fetch serialized accumulators for job $jobID", e)
Collections.emptyMap()
}
val result = new SerializedJobExecutionResult(jobID, jobInfo.duration,
accumulatorResults)
val result = new SerializedJobExecutionResult(
jobID,
jobInfo.duration,
accumulatorResults)
jobInfo.client ! decorateMessage(JobResultSuccess(result))
case JobStatus.CANCELED =>
......@@ -352,9 +357,8 @@ class JobManager(
Failure(
new JobCancellationException(
jobID,
"Job was cancelled.", error)
)
)
"Job was cancelled.",
new SerializedThrowable(deserializedError))))
case JobStatus.FAILED =>
jobInfo.client ! decorateMessage(
......@@ -362,14 +366,11 @@ class JobManager(
new JobExecutionException(
jobID,
"Job execution failed.",
error)
)
)
new SerializedThrowable(deserializedError))))
case x =>
val exception = new JobExecutionException(jobID, s"$x is not a " +
"terminal state.")
jobInfo.client ! decorateMessage(Failure(exception))
val exception = new JobExecutionException(jobID, s"$x is not a terminal state.")
jobInfo.client ! decorateMessage(Failure(new SerializedThrowable(exception)))
throw exception
}
......
......@@ -25,6 +25,7 @@ import org.apache.flink.api.common.JobID
import org.apache.flink.runtime.execution.ExecutionState
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID
import org.apache.flink.runtime.jobgraph.{JobStatus, JobVertexID}
import org.apache.flink.runtime.util.SerializedThrowable
/**
* This object contains the execution graph specific messages.
......@@ -74,13 +75,13 @@ object ExecutionGraphMessages {
* @param jobID identifying the corresponding job
* @param newJobStatus
* @param timestamp
* @param error
* @param serializedError
*/
case class JobStatusChanged(
jobID: JobID,
newJobStatus: JobStatus,
timestamp: Long,
error: Throwable)
serializedError: SerializedThrowable)
extends RequiresLeaderSessionID {
override def toString: String = {
s"${timestampToString(timestamp)}\tJob execution switched to status $newJobStatus."
......
......@@ -19,6 +19,7 @@
package org.apache.flink.runtime.messages
import org.apache.flink.runtime.jobgraph.JobGraph
import org.apache.flink.runtime.util.SerializedThrowable
/**
* This object contains the [[org.apache.flink.runtime.client.JobClient]] specific messages
......
......@@ -24,7 +24,7 @@ import akka.pattern.Patterns.gracefulStop
import akka.pattern.ask
import akka.actor.{ActorRef, ActorSystem}
import com.typesafe.config.Config
import org.apache.flink.api.common.JobSubmissionResult
import org.apache.flink.api.common.{JobExecutionResult, JobSubmissionResult}
import org.apache.flink.configuration.{ConfigConstants, Configuration}
import org.apache.flink.runtime.StreamingMode
import org.apache.flink.runtime.akka.AkkaUtils
......@@ -238,9 +238,7 @@ abstract class FlinkMiniCluster(
}
@throws(classOf[JobExecutionException])
def submitJobAndWait(jobGraph: JobGraph, printUpdates: Boolean)
: SerializedJobExecutionResult = {
def submitJobAndWait(jobGraph: JobGraph, printUpdates: Boolean): JobExecutionResult = {
submitJobAndWait(jobGraph, printUpdates, timeout)
}
......@@ -249,7 +247,7 @@ abstract class FlinkMiniCluster(
jobGraph: JobGraph,
printUpdates: Boolean,
timeout: FiniteDuration)
: SerializedJobExecutionResult = {
: JobExecutionResult = {
val clientActorSystem = if (singleActorSystem) jobManagerActorSystem
else JobClient.startJobClientActorSystem(configuration)
......@@ -259,7 +257,8 @@ abstract class FlinkMiniCluster(
getJobManagerGateway(),
jobGraph,
timeout,
printUpdates)
printUpdates,
this.getClass.getClassLoader)
}
@throws(classOf[JobExecutionException])
......
......@@ -105,7 +105,7 @@ public class PartialConsumePipelinedResultTest {
flink.getJobManagerGateway(),
jobGraph,
TestingUtils.TESTING_DURATION(),
false);
false, this.getClass().getClassLoader());
}
// ---------------------------------------------------------------------------------------------
......
......@@ -83,7 +83,7 @@ public class CheckpointMessagesTest {
private static final long serialVersionUID = 8128146204128728332L;
@Override
public Serializable getState() {
public Serializable getState(ClassLoader userCodeClassLoader) {
return null;
}
......
......@@ -36,6 +36,7 @@ public class ByteStreamStateHandleTest {
@Test
public void testHandle() throws Exception {
final ClassLoader cl = this.getClass().getClassLoader();
MockHandle handle;
try {
......@@ -47,14 +48,14 @@ public class ByteStreamStateHandleTest {
handle = new MockHandle(1);
assertEquals(1, handle.getState());
assertEquals(1, handle.getState(cl));
assertTrue(handle.stateFetched());
assertFalse(handle.isWritten());
assertFalse(handle.discarded);
MockHandle handleDs = serializeDeserialize(handle);
assertEquals(1, handle.getState());
assertEquals(1, handle.getState(cl));
assertTrue(handle.stateFetched());
assertTrue(handle.isWritten());
assertTrue(handle.generatedOutput);
......@@ -66,7 +67,7 @@ public class ByteStreamStateHandleTest {
assertFalse(handle.discarded);
try {
handleDs.getState();
handleDs.getState(cl);
fail();
} catch (UnsupportedOperationException e) {
// good
......
......@@ -28,6 +28,7 @@ import org.apache.flink.runtime.jobgraph.{JobVertex, DistributionPattern, JobGra
import org.apache.flink.runtime.messages.JobManagerMessages._
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.NotifyWhenJobRemoved
import org.apache.flink.runtime.testingUtils.{ScalaTestingUtils, TestingUtils}
import org.apache.flink.runtime.util.SerializedThrowable
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}
......@@ -84,11 +85,13 @@ class JobManagerITCase(_system: ActorSystem)
within(2 second) {
val response = expectMsgType[Failure]
val exception = response.cause
val exception = SerializedThrowable.get(response.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
new NoResourceAvailableException(1,1,0) should equal(e.getCause)
val cause = e.getCause.asInstanceOf[SerializedThrowable].deserializeError(
this.getClass.getClassLoader)
new NoResourceAvailableException(1,1,0) should equal(cause)
case e => fail(s"Received wrong exception of type $e.")
}
}
......@@ -261,8 +264,9 @@ class JobManagerITCase(_system: ActorSystem)
expectMsg(Success(jobGraph.getJobID))
val failure = expectMsgType[Failure]
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
failure.cause match {
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
......@@ -388,8 +392,8 @@ class JobManagerITCase(_system: ActorSystem)
expectMsg(Success(jobGraph.getJobID))
val failure = expectMsgType[Failure]
failure.cause match {
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
......@@ -434,9 +438,10 @@ class JobManagerITCase(_system: ActorSystem)
within(TestingUtils.TESTING_DURATION) {
jmGateway.tell(SubmitJob(jobGraph, false), self)
expectMsg(Success(jobGraph.getJobID))
val failure = expectMsgType[Failure]
failure.cause match {
val failure = expectMsgType[Failure]
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
......@@ -473,9 +478,10 @@ class JobManagerITCase(_system: ActorSystem)
within(TestingUtils.TESTING_DURATION) {
jmGateway.tell(SubmitJob(jobGraph, false), self)
expectMsg(Success(jobGraph.getJobID))
val failure = expectMsgType[Failure]
failure.cause match {
val failure = expectMsgType[Failure]
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
......@@ -515,9 +521,10 @@ class JobManagerITCase(_system: ActorSystem)
jmGateway.tell(SubmitJob(jobGraph, false), self)
expectMsg(Success(jobGraph.getJobID))
val failure = expectMsgType[Failure]
failure.cause match {
val failure = expectMsgType[Failure]
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
......@@ -562,9 +569,10 @@ class JobManagerITCase(_system: ActorSystem)
jmGateway.tell(SubmitJob(jobGraph, false), self)
expectMsg(Success(jobGraph.getJobID))
val failure = expectMsgType[Failure]
failure.cause match {
val failure = expectMsgType[Failure]
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
......
......@@ -28,6 +28,7 @@ import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup
import org.apache.flink.runtime.messages.JobManagerMessages.SubmitJob
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages._
import org.apache.flink.runtime.testingUtils.{ScalaTestingUtils, TestingUtils}
import org.apache.flink.runtime.util.SerializedThrowable
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}
......@@ -85,8 +86,8 @@ class TaskManagerFailsWithSlotSharingITCase(_system: ActorSystem)
taskManagers(0) ! PoisonPill
val failure = expectMsgType[Failure]
failure.cause match {
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
case e => fail(s"Received wrong exception $e.")
......@@ -133,8 +134,8 @@ class TaskManagerFailsWithSlotSharingITCase(_system: ActorSystem)
taskManagers(0) ! Kill
val failure = expectMsgType[Failure]
failure.cause match {
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
......
......@@ -114,7 +114,7 @@ public class FileStateHandleTest {
assertFalse(deserializedHandle.stateFetched());
// Fetch the and compare with original
assertEquals(state, deserializedHandle.getState());
assertEquals(state, deserializedHandle.getState(this.getClass().getClassLoader()));
// Test whether discard removes the checkpoint file properly
assertTrue(hdfs.listFiles(hdPath, true).hasNext());
......
......@@ -78,9 +78,10 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function & Serial
@Override
@SuppressWarnings({ "unchecked", "rawtypes" })
public void restoreInitialState(Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> snapshots) throws Exception {
// Restore state using the Checkpointed interface
if (userFunction instanceof Checkpointed && snapshots.f0 != null) {
((Checkpointed) userFunction).restoreState(snapshots.f0.getState());
((Checkpointed) userFunction).restoreState(snapshots.f0.getState(runtimeContext.getUserCodeClassLoader()));
}
if (snapshots.f1 != null) {
......@@ -88,7 +89,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function & Serial
for (Entry<String, OperatorStateHandle> snapshot : snapshots.f1.entrySet()) {
StreamOperatorState restoredOpState = runtimeContext.getState(snapshot.getKey(), snapshot.getValue().isPartitioned());
StateHandle<Serializable> checkpointHandle = snapshot.getValue();
restoredOpState.restoreState(checkpointHandle);
restoredOpState.restoreState(checkpointHandle, runtimeContext.getUserCodeClassLoader());
}
}
......
......@@ -69,10 +69,10 @@ public class EagerStateStore<S, C extends Serializable> implements PartitionedSt
}
@Override
public void restoreStates(StateHandle<Serializable> snapshot) throws Exception {
public void restoreStates(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader) throws Exception {
@SuppressWarnings("unchecked")
Map<Serializable, C> checkpoints = (Map<Serializable, C>) snapshot.getState();
Map<Serializable, C> checkpoints = (Map<Serializable, C>) snapshot.getState(userCodeClassLoader);
// we map the values back to the state from the checkpoints
for (Entry<Serializable, C> snapshotEntry : checkpoints.entrySet()) {
......
......@@ -38,8 +38,8 @@ public class OperatorStateHandle implements StateHandle<Serializable> {
}
@Override
public Serializable getState() throws Exception {
return handle.getState();
public Serializable getState(ClassLoader userCodeClassLoader) throws Exception {
return handle.getState(userCodeClassLoader);
}
@Override
......
......@@ -43,7 +43,7 @@ public interface PartitionedStateStore<S, C extends Serializable> {
StateHandle<Serializable> snapshotStates(long checkpointId, long checkpointTimestamp) throws Exception;
void restoreStates(StateHandle<Serializable> snapshot) throws Exception;
void restoreStates(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader) throws Exception;
boolean containsKey(Serializable key);
......
......@@ -128,8 +128,8 @@ public class PartitionedStreamOperatorState<IN, S, C extends Serializable> exten
}
@Override
public void restoreState(StateHandle<Serializable> snapshots) throws Exception {
stateStore.restoreStates(snapshots);
public void restoreState(StateHandle<Serializable> snapshots, ClassLoader userCodeClassLoader) throws Exception {
stateStore.restoreStates(snapshots, userCodeClassLoader);
}
@Override
......
......@@ -96,8 +96,8 @@ public class StreamOperatorState<S, C extends Serializable> implements OperatorS
}
@SuppressWarnings("unchecked")
public void restoreState(StateHandle<Serializable> snapshot) throws Exception {
update((S) checkpointer.restoreState((C) snapshot.getState()));
public void restoreState(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader) throws Exception {
update(checkpointer.restoreState((C) snapshot.getState(userCodeClassLoader)));
}
public Map<Serializable, S> getPartitionedState() throws Exception {
......
......@@ -42,7 +42,8 @@ public class WrapperStateHandle extends LocalStateHandle<Serializable> {
@Override
public void discardState() throws Exception {
@SuppressWarnings("unchecked")
List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>> chainedStates = (List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>>) getState();
List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>> chainedStates =
(List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>>) getState(null); // we can pass "null" here because the LocalStateHandle is not using the ClassLoader anyways
for (Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> state : chainedStates) {
if (state != null) {
if (state.f0 != null) {
......
......@@ -29,7 +29,7 @@ import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.GlobalConfiguration;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
......@@ -106,7 +106,7 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
public StreamTask() {
checkpointBarrierListener = new CheckpointBarrierListener();
contexts = new ArrayList<StreamingRuntimeContext>();
contexts = new ArrayList<>();
}
// ------------------------------------------------------------------------
......@@ -271,7 +271,7 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
// We retrieve end restore the states for the chained operators.
List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>> chainedStates =
(List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>>) stateHandle.getState();
(List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>>) stateHandle.getState(this.userClassLoader);
// We restore all stateful operators
for (int i = 0; i < chainedStates.size(); i++) {
......@@ -358,7 +358,8 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
// If the user did not specify a provider in the program we try to get it from the config
if (provider == null) {
String backendName = GlobalConfiguration.getString(ConfigConstants.STATE_BACKEND,
Configuration flinkConfig = getEnvironment().getTaskManagerInfo().getConfiguration();
String backendName = flinkConfig.getString(ConfigConstants.STATE_BACKEND,
ConfigConstants.DEFAULT_STATE_BACKEND).toUpperCase();
StateBackend backend;
......@@ -372,9 +373,9 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
switch (backend) {
case JOBMANAGER:
LOG.info("State backend for state checkpoints is set to jobmanager.");
return new LocalStateHandle.LocalStateHandleProvider<Serializable>();
return new LocalStateHandle.LocalStateHandleProvider<>();
case FILESYSTEM:
String checkpointDir = GlobalConfiguration.getString(ConfigConstants.STATE_BACKEND_FS_DIR, null);
String checkpointDir = flinkConfig.getString(ConfigConstants.STATE_BACKEND_FS_DIR, null);
if (checkpointDir != null) {
LOG.info("State backend for state checkpoints is set to filesystem with directory: "
+ checkpointDir);
......
......@@ -45,7 +45,7 @@ public class StateHandleTest {
MockHandle<Serializable> h1 = new MockHandle<Serializable>(1);
OperatorStateHandle opHandle = new OperatorStateHandle(h1, true);
assertEquals(1, opHandle.getState());
assertEquals(1, opHandle.getState(this.getClass().getClassLoader()));
OperatorStateHandle dsHandle = serializeDeserialize(opHandle);
MockHandle<Serializable> h2 = (MockHandle<Serializable>) dsHandle.getHandle();
......@@ -60,6 +60,7 @@ public class StateHandleTest {
@Test
public void wrapperStateHandleTest() throws Exception {
final ClassLoader cl = this.getClass().getClassLoader();
MockHandle<Serializable> h1 = new MockHandle<Serializable>(1);
MockHandle<Serializable> h2 = new MockHandle<Serializable>(2);
......@@ -82,16 +83,16 @@ public class StateHandleTest {
@SuppressWarnings("unchecked")
Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> dsFullState = ((List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>>) dsWrapper
.getState()).get(0);
.getState(cl)).get(0);
Map<String, OperatorStateHandle> dsOpHandles = dsFullState.f1;
assertNull(dsFullState.f0.getState());
assertNull(dsFullState.f0.getState(cl));
assertFalse(((MockHandle<?>) dsFullState.f0).discarded);
assertFalse(((MockHandle<?>) dsOpHandles.get("h1").getHandle()).discarded);
assertNull(dsOpHandles.get("h1").getState());
assertNull(dsOpHandles.get("h1").getState(cl));
assertFalse(((MockHandle<?>) dsOpHandles.get("h2").getHandle()).discarded);
assertNull(dsOpHandles.get("h2").getState());
assertNull(dsOpHandles.get("h2").getState(cl));
dsWrapper.discardState();
......@@ -126,7 +127,7 @@ public class StateHandleTest {
}
@Override
public T getState() {
public T getState(ClassLoader userCodeClassLoader) {
return state;
}
}
......
......@@ -52,7 +52,6 @@ import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
import org.apache.flink.streaming.util.TestStreamEnvironment;
import org.apache.flink.util.InstantiationUtil;
import org.junit.Test;
......@@ -170,9 +169,9 @@ public class StatefulOperatorTest extends StreamingMultipleProgramsTestBase {
}, context);
if (serializedState != null) {
ClassLoader cl = Thread.currentThread().getContextClassLoader();
op.restoreInitialState((Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>) InstantiationUtil
.deserializeObject(serializedState, Thread.currentThread()
.getContextClassLoader()));
.deserializeObject(serializedState, cl));
}
op.open(null);
......
......@@ -73,8 +73,7 @@ public class TestStreamEnvironment extends StreamExecutionEnvironment {
}
try {
sync = true;
SerializedJobExecutionResult result = executor.submitJobAndWait(jobGraph, false);
latestResult = result.toJobExecutionResult(getClass().getClassLoader());
latestResult = executor.submitJobAndWait(jobGraph, false);
return latestResult;
} catch (JobExecutionException e) {
if (e.getMessage().contains("GraphConversionException")) {
......@@ -116,8 +115,7 @@ public class TestStreamEnvironment extends StreamExecutionEnvironment {
jobRunner = new Thread() {
public void run() {
try {
SerializedJobExecutionResult result = cluster.submitJobAndWait(jobGraph, false);
latestResult = result.toJobExecutionResult(getClass().getClassLoader());
latestResult = cluster.submitJobAndWait(jobGraph, false);
} catch (JobExecutionException e) {
// TODO remove: hack to make ITCase succeed because .submitJobAndWait() throws exception on .stop() (see this.shutdown())
latestResult = new JobExecutionResult(null, 0, null);
......
......@@ -27,7 +27,6 @@ import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plandump.PlanJSONDumpGenerator;
import org.apache.flink.optimizer.plantranslate.JobGraphGenerator;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.client.SerializedJobExecutionResult;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.junit.Assert;
......@@ -121,8 +120,7 @@ public abstract class RecordAPITestBase extends AbstractTestBase {
Assert.assertNotNull("Obtained null JobGraph", jobGraph);
try {
SerializedJobExecutionResult result = executor.submitJobAndWait(jobGraph, false);
this.jobExecutionResult = result.toJobExecutionResult(getClass().getClassLoader());
this.jobExecutionResult = executor.submitJobAndWait(jobGraph, false);
}
catch (Exception e) {
System.err.println(e.getMessage());
......
......@@ -28,7 +28,6 @@ import org.apache.flink.optimizer.Optimizer;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plandump.PlanJSONDumpGenerator;
import org.apache.flink.optimizer.plantranslate.JobGraphGenerator;
import org.apache.flink.runtime.client.SerializedJobExecutionResult;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.junit.Assert;
......@@ -51,9 +50,7 @@ public class TestEnvironment extends ExecutionEnvironment {
JobGraphGenerator jgg = new JobGraphGenerator();
JobGraph jobGraph = jgg.compileJobGraph(op);
SerializedJobExecutionResult result = executor.submitJobAndWait(jobGraph, false);
this.lastJobExecutionResult = result.toJobExecutionResult(getClass().getClassLoader());
this.lastJobExecutionResult = executor.submitJobAndWait(jobGraph, false);
return this.lastJobExecutionResult;
}
catch (Exception e) {
......
......@@ -387,6 +387,25 @@ under the License.
</descriptors>
</configuration>
</execution>
<execution>
<id>create-streaming-state-checkpointed-classloader-jar</id>
<phase>process-test-classes</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<archive>
<manifest>
<mainClass>org.apache.flink.test.classloading.jar.CheckpointedStreamingProgram</mainClass>
</manifest>
</archive>
<finalName>streaming-checkpointed-classloader</finalName>
<attach>false</attach>
<descriptors>
<descriptor>src/test/assembly/test-streaming-state-checkpointed-classloader-assembly.xml</descriptor>
</descriptors>
</configuration>
</execution>
</executions>
</plugin>
......
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<assembly>
<id>test-jar</id>
<formats>
<format>jar</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<fileSets>
<fileSet>
<directory>${project.build.testOutputDirectory}</directory>
<outputDirectory>/</outputDirectory>
<!--modify/add include to match your package(s) -->
<includes>
<include>org/apache/flink/test/classloading/jar/CheckpointedStreamingProgram.class</include>
<include>org/apache/flink/test/classloading/jar/CheckpointedStreamingProgram$*.class</include>
<include>org/apache/flink/test/testdata/WordCountData.class</include>
</includes>
</fileSet>
</fileSets>
</assembly>
\ No newline at end of file
......@@ -27,7 +27,9 @@ import org.apache.flink.test.testdata.KMeansData;
import org.apache.flink.test.util.ForkableFlinkMiniCluster;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
public class ClassLoaderITCase {
......@@ -35,14 +37,25 @@ public class ClassLoaderITCase {
private static final String STREAMING_PROG_JAR_FILE = "target/streamingclassloader-test-jar.jar";
private static final String STREAMING_CHECKPOINTED_PROG_JAR_FILE = "target/streaming-checkpointed-classloader-test-jar.jar";
private static final String KMEANS_JAR_PATH = "target/kmeans-test-jar.jar";
@Rule
public TemporaryFolder folder = new TemporaryFolder();
@Test
public void testJobWithCustomInputFormat() {
public void testJobsWithCustomClassLoader() {
try {
Configuration config = new Configuration();
config.setInteger(ConfigConstants.LOCAL_INSTANCE_MANAGER_NUMBER_TASK_MANAGER, 2);
config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, 2);
config.setString(ConfigConstants.DEFAULT_EXECUTION_RETRY_DELAY_KEY, "0 s");
// we need to use the "filesystem" state backend to ensure FLINK-2543 is not happening again.
config.setString(ConfigConstants.STATE_BACKEND, "filesystem");
config.setString(ConfigConstants.STATE_BACKEND_FS_DIR, "file://" + folder.newFolder().getAbsolutePath());
ForkableFlinkMiniCluster testCluster = new ForkableFlinkMiniCluster(config, false);
......@@ -57,10 +70,28 @@ public class ClassLoaderITCase {
} );
inputSplitTestProg.invokeInteractiveModeForExecution();
// regular streaming job
PackagedProgram streamingProg = new PackagedProgram(new File(STREAMING_PROG_JAR_FILE),
new String[] { STREAMING_PROG_JAR_FILE, "localhost", String.valueOf(port) } );
streamingProg.invokeInteractiveModeForExecution();
// checkpointed streaming job with custom classes for the checkpoint (FLINK-2543)
// the test also ensures that user specific exceptions are serializable between JobManager <--> JobClient.
try {
PackagedProgram streamingCheckpointedProg = new PackagedProgram(new File(STREAMING_CHECKPOINTED_PROG_JAR_FILE),
new String[]{STREAMING_CHECKPOINTED_PROG_JAR_FILE, "localhost", String.valueOf(port)});
streamingCheckpointedProg.invokeInteractiveModeForExecution();
} catch(Exception e) {
// we can not access the SuccessException here when executing the tests with maven, because its not available in the jar.
try {
if (!(e.getCause().getCause().getClass().getCanonicalName().equals("org.apache.flink.test.classloading.jar.CheckpointedStreamingProgram.SuccessException"))) {
throw e;
}
} catch(Throwable ignore) {
throw e;
}
}
PackagedProgram kMeansProg = new PackagedProgram(new File(KMEANS_JAR_PATH),
new String[] { KMEANS_JAR_PATH,
"localhost",
......@@ -81,4 +112,5 @@ public class ClassLoaderITCase {
Assert.fail(e.getMessage());
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.classloading.jar;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import java.lang.RuntimeException;
import java.net.URL;
import java.net.URLClassLoader;
/**
* A simple streaming program, which is using the state checkpointing of Flink.
* It is using a user defined class as the state.
*/
@SuppressWarnings("serial")
public class CheckpointedStreamingProgram {
private static final int CHECKPOINT_INTERVALL = 100;
public static void main(String[] args) throws Exception {
ClassLoader cl = ClassLoader.getSystemClassLoader();
URL[] urls = ((URLClassLoader)cl).getURLs();
for(URL url: urls){
System.out.println(url.getFile());
}
System.out.println("CheckpointedStreamingProgram classpath: ");
final String jarFile = args[0];
final String host = args[1];
final int port = Integer.parseInt(args[2]);
StreamExecutionEnvironment env = StreamExecutionEnvironment.createRemoteEnvironment(host, port, jarFile);
env.getConfig().disableSysoutLogging();
env.enableCheckpointing(CHECKPOINT_INTERVALL);
env.setNumberOfExecutionRetries(1);
env.disableOperatorChaining();
DataStream<String> text = env.addSource(new SimpleStringGenerator());
text.map(new StatefulMapper()).addSink(new NoOpSink());
env.setParallelism(1);
env.execute("Checkpointed Streaming Program");
}
// with Checkpoining
public static class SimpleStringGenerator implements SourceFunction<String>, Checkpointed<Integer> {
public boolean running = true;
@Override
public void run(SourceContext<String> ctx) throws Exception {
while(running) {
Thread.sleep(1);
ctx.collect("someString");
}
}
@Override
public void cancel() {
running = false;
}
@Override
public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
return null;
}
@Override
public void restoreState(Integer state) {
}
}
public static class StatefulMapper implements MapFunction<String, String>, Checkpointed<StatefulMapper>, CheckpointNotifier {
private String someState;
private boolean atLeastOneSnapshotComplete = false;
private boolean restored = false;
@Override
public StatefulMapper snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
return this;
}
@Override
public void restoreState(StatefulMapper state) {
restored = true;
this.someState = state.someState;
this.atLeastOneSnapshotComplete = state.atLeastOneSnapshotComplete;
}
@Override
public String map(String value) throws Exception {
if(!atLeastOneSnapshotComplete) {
// throttle consumption by the checkpoint interval until we have one snapshot.
Thread.sleep(CHECKPOINT_INTERVALL);
}
if(atLeastOneSnapshotComplete && !restored) {
throw new RuntimeException("Intended failure, to trigger restore");
}
if(restored) {
throw new SuccessException();
//throw new RuntimeException("All good");
}
someState = value; // update our state
return value;
}
@Override
public void notifyCheckpointComplete(long checkpointId) throws Exception {
atLeastOneSnapshotComplete = true;
}
}
// --------------------------------------------------------------------------------------------
/**
* We intentionally use a user specified failure exception
*/
public static class SuccessException extends Exception {
}
public static class NoOpSink implements SinkFunction<String>{
@Override
public void invoke(String value) throws Exception {
}
}
}
......@@ -23,7 +23,6 @@ import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.client.JobExecutionException;
import org.apache.flink.runtime.client.JobSubmissionException;
import org.apache.flink.runtime.client.SerializedJobExecutionResult;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobmanager.Tasks;
......@@ -48,7 +47,7 @@ public class JobSubmissionFailsITCase {
private static final int NUM_SLOTS = 20;
private static ForkableFlinkMiniCluster cluser;
private static ForkableFlinkMiniCluster cluster;
private static JobGraph workingJobGraph;
@BeforeClass
......@@ -59,7 +58,7 @@ public class JobSubmissionFailsITCase {
config.setInteger(ConfigConstants.LOCAL_INSTANCE_MANAGER_NUMBER_TASK_MANAGER, 2);
config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, NUM_SLOTS / 2);
cluser = new ForkableFlinkMiniCluster(config);
cluster = new ForkableFlinkMiniCluster(config);
final JobVertex jobVertex = new JobVertex("Working job vertex.");
jobVertex.setInvokableClass(Tasks.NoOpInvokable.class);
......@@ -74,7 +73,7 @@ public class JobSubmissionFailsITCase {
@AfterClass
public static void teardown() {
try {
cluser.shutdown();
cluster.shutdown();
}
catch (Exception e) {
e.printStackTrace();
......@@ -100,13 +99,11 @@ public class JobSubmissionFailsITCase {
private JobExecutionResult submitJob(JobGraph jobGraph) throws Exception {
if (detached) {
cluser.submitJobDetached(jobGraph);
cluster.submitJobDetached(jobGraph);
return null;
}
else {
SerializedJobExecutionResult result = cluser.submitJobAndWait(
jobGraph, false, TestingUtils.TESTING_DURATION());
return result.toJobExecutionResult(getClass().getClassLoader());
return cluster.submitJobAndWait(jobGraph, false, TestingUtils.TESTING_DURATION());
}
}
......@@ -130,7 +127,7 @@ public class JobSubmissionFailsITCase {
fail("Caught wrong exception of type " + t.getClass() + ".");
}
cluser.submitJobAndWait(workingJobGraph, false);
cluster.submitJobAndWait(workingJobGraph, false);
}
catch (Exception e) {
e.printStackTrace();
......@@ -155,7 +152,7 @@ public class JobSubmissionFailsITCase {
fail("Caught wrong exception of type " + t.getClass() + ".");
}
cluser.submitJobAndWait(workingJobGraph, false);
cluster.submitJobAndWait(workingJobGraph, false);
}
catch (Exception e) {
e.printStackTrace();
......@@ -178,7 +175,7 @@ public class JobSubmissionFailsITCase {
fail("Caught wrong exception of type " + t.getClass() + ".");
}
cluser.submitJobAndWait(workingJobGraph, false);
cluster.submitJobAndWait(workingJobGraph, false);
}
catch (Exception e) {
e.printStackTrace();
......
......@@ -33,6 +33,7 @@ import org.apache.flink.runtime.messages.TaskManagerMessages.{RegisteredAtJobMan
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages._
import org.apache.flink.runtime.testingUtils.TestingMessages.DisableDisconnect
import org.apache.flink.runtime.testingUtils.{ScalaTestingUtils, TestingUtils}
import org.apache.flink.runtime.util.SerializedThrowable
import org.apache.flink.test.util.ForkableFlinkMiniCluster
import org.junit.runner.RunWith
......@@ -126,8 +127,8 @@ class TaskManagerFailsITCase(_system: ActorSystem)
}
val failure = expectMsgType[Failure]
failure.cause match {
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
......@@ -169,8 +170,8 @@ class TaskManagerFailsITCase(_system: ActorSystem)
taskManagers(0) ! Kill
val failure = expectMsgType[Failure]
failure.cause match {
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
......@@ -208,8 +209,8 @@ class TaskManagerFailsITCase(_system: ActorSystem)
tm ! PoisonPill
val failure = expectMsgType[Failure]
failure.cause match {
val exception = SerializedThrowable.get(failure.cause, this.getClass.getClassLoader)
exception match {
case e: JobExecutionException =>
jobGraph.getJobID should equal(e.getJobID)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册