提交 8e613014 编写于 作者: S Stephan Ewen

[FLINK-1672] [runtime] Unify Task and RuntimeEnvironment into one class.

 - This simplifies and hardens the failure handling during task startup
 - Guarantees that no actor system threads are blocked by task bootstrap, or task canceling
 - Corrects some previously erroneous corner case state transitions
 - Adds simple and robust tests
上级 1d368a4b
......@@ -18,6 +18,7 @@
package org.apache.flink.runtime.execution;
import akka.actor.ActorRef;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
......@@ -159,4 +160,6 @@ public interface Environment {
InputGate[] getAllInputGates();
// this should go away
ActorRef getJobManager();
}
......@@ -35,10 +35,10 @@ package org.apache.flink.runtime.execution;
* ... -> FAILED
* </pre>
*
* It is possible to enter the {@code FAILED} state from any other state.
* <p>It is possible to enter the {@code FAILED} state from any other state.</p>
*
* The states {@code FINISHED}, {@code CANCELED}, and {@code FAILED} are
* considered terminal states.
* <p>The states {@code FINISHED}, {@code CANCELED}, and {@code FAILED} are
* considered terminal states.</p>
*/
public enum ExecutionState {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.execution;
import akka.actor.ActorRef;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.accumulators.AccumulatorEvent;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.NetworkEnvironment;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
import org.apache.flink.runtime.memorymanager.MemoryManager;
import org.apache.flink.runtime.messages.accumulators.ReportAccumulatorResult;
import org.apache.flink.runtime.taskmanager.Task;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.FutureTask;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.google.common.base.Preconditions.checkElementIndex;
import static com.google.common.base.Preconditions.checkNotNull;
public class RuntimeEnvironment implements Environment, Runnable {
private static final Logger LOG = LoggerFactory.getLogger(RuntimeEnvironment.class);
private static final ThreadGroup TASK_THREADS = new ThreadGroup("Task Threads");
/** The ActorRef to the job manager */
private final ActorRef jobManager;
/** The task that owns this environment */
private final Task owner;
/** The job configuration encapsulated in the environment object. */
private final Configuration jobConfiguration;
/** The task configuration encapsulated in the environment object. */
private final Configuration taskConfiguration;
/** ClassLoader for all user code classes */
private final ClassLoader userCodeClassLoader;
/** Instance of the class to be run in this environment. */
private final AbstractInvokable invokable;
/** The memory manager of the current environment (currently the one associated with the executing TaskManager). */
private final MemoryManager memoryManager;
/** The I/O manager of the current environment (currently the one associated with the executing TaskManager). */
private final IOManager ioManager;
/** The input split provider that can be queried for new input splits. */
private final InputSplitProvider inputSplitProvider;
/** The thread executing the task in the environment. */
private Thread executingThread;
private final BroadcastVariableManager broadcastVariableManager;
private final Map<String, FutureTask<Path>> cacheCopyTasks = new HashMap<String, FutureTask<Path>>();
private final AtomicBoolean canceled = new AtomicBoolean();
private final ResultPartition[] producedPartitions;
private final ResultPartitionWriter[] writers;
private final SingleInputGate[] inputGates;
private final Map<IntermediateDataSetID, SingleInputGate> inputGatesById = new HashMap<IntermediateDataSetID, SingleInputGate>();
public RuntimeEnvironment(
ActorRef jobManager, Task owner, TaskDeploymentDescriptor tdd, ClassLoader userCodeClassLoader,
MemoryManager memoryManager, IOManager ioManager, InputSplitProvider inputSplitProvider,
BroadcastVariableManager broadcastVariableManager, NetworkEnvironment networkEnvironment) throws Exception {
this.owner = checkNotNull(owner);
this.memoryManager = checkNotNull(memoryManager);
this.ioManager = checkNotNull(ioManager);
this.inputSplitProvider = checkNotNull(inputSplitProvider);
this.jobManager = checkNotNull(jobManager);
this.broadcastVariableManager = checkNotNull(broadcastVariableManager);
try {
// Produced intermediate result partitions
final List<ResultPartitionDeploymentDescriptor> partitions = tdd.getProducedPartitions();
this.producedPartitions = new ResultPartition[partitions.size()];
this.writers = new ResultPartitionWriter[partitions.size()];
for (int i = 0; i < this.producedPartitions.length; i++) {
ResultPartitionDeploymentDescriptor desc = partitions.get(i);
ResultPartitionID partitionId = new ResultPartitionID(desc.getPartitionId(), owner.getExecutionId());
this.producedPartitions[i] = new ResultPartition(
this,
owner.getJobID(),
partitionId,
desc.getPartitionType(),
desc.getNumberOfSubpartitions(),
networkEnvironment.getPartitionManager(),
networkEnvironment.getPartitionConsumableNotifier(),
ioManager,
networkEnvironment.getDefaultIOMode());
writers[i] = new ResultPartitionWriter(this.producedPartitions[i]);
}
// Consumed intermediate result partitions
final List<InputGateDeploymentDescriptor> consumedPartitions = tdd.getInputGates();
this.inputGates = new SingleInputGate[consumedPartitions.size()];
for (int i = 0; i < inputGates.length; i++) {
inputGates[i] = SingleInputGate.create(
this, consumedPartitions.get(i), networkEnvironment);
// The input gates are organized by key for task updates/channel updates at runtime
inputGatesById.put(inputGates[i].getConsumedResultId(), inputGates[i]);
}
this.jobConfiguration = tdd.getJobConfiguration();
this.taskConfiguration = tdd.getTaskConfiguration();
// ----------------------------------------------------------------
// Invokable setup
// ----------------------------------------------------------------
// Note: This has to be done *after* the readers and writers have
// been setup, because the invokable relies on them for I/O.
// ----------------------------------------------------------------
// Load and instantiate the invokable class
this.userCodeClassLoader = checkNotNull(userCodeClassLoader);
// Class of the task to run in this environment
Class<? extends AbstractInvokable> invokableClass;
try {
final String className = tdd.getInvokableClassName();
invokableClass = Class.forName(className, true, userCodeClassLoader).asSubclass(AbstractInvokable.class);
}
catch (Throwable t) {
throw new Exception("Could not load invokable class.", t);
}
try {
this.invokable = invokableClass.newInstance();
}
catch (Throwable t) {
throw new Exception("Could not instantiate the invokable class.", t);
}
this.invokable.setEnvironment(this);
this.invokable.registerInputOutput();
}
catch (Throwable t) {
throw new Exception("Error setting up runtime environment: " + t.getMessage(), t);
}
}
/**
* Returns the task invokable instance.
*/
public AbstractInvokable getInvokable() {
return this.invokable;
}
@Override
public JobID getJobID() {
return this.owner.getJobID();
}
@Override
public JobVertexID getJobVertexId() {
return this.owner.getVertexID();
}
@Override
public void run() {
// quick fail in case the task was cancelled while the thread was started
if (owner.isCanceledOrFailed()) {
owner.cancelingDone();
return;
}
try {
Thread.currentThread().setContextClassLoader(userCodeClassLoader);
invokable.invoke();
// Make sure, we enter the catch block when the task has been canceled
if (owner.isCanceledOrFailed()) {
throw new CancelTaskException("Task has been canceled or failed");
}
// Finish the produced partitions
if (producedPartitions != null) {
for (ResultPartition partition : producedPartitions) {
if (partition != null) {
partition.finish();
}
}
}
if (owner.isCanceledOrFailed()) {
throw new CancelTaskException();
}
// Finally, switch execution state to FINISHED and report to job manager
if (!owner.markAsFinished()) {
throw new Exception("Could *not* notify job manager that the task is finished.");
}
}
catch (Throwable t) {
if (!owner.isCanceledOrFailed()) {
// Perform clean up when the task failed and has been not canceled by the user
try {
invokable.cancel();
}
catch (Throwable t2) {
LOG.error("Error while canceling the task", t2);
}
}
// if we are already set as cancelled or failed (when failure is triggered externally),
// mark that the thread is done.
if (owner.isCanceledOrFailed() || t instanceof CancelTaskException) {
owner.cancelingDone();
}
else {
// failure from inside the task thread. notify the task of the failure
owner.markFailed(t);
}
}
}
/**
* Returns the thread, which is assigned to execute the user code.
*/
public Thread getExecutingThread() {
synchronized (this) {
if (executingThread == null) {
String name = owner.getTaskNameWithSubtasks();
if (LOG.isDebugEnabled()) {
name = name + " (" + owner.getExecutionId() + ")";
}
executingThread = new Thread(TASK_THREADS, this, name);
}
return executingThread;
}
}
public void cancelExecution() {
if (!canceled.compareAndSet(false, true)) {
return;
}
LOG.info("Canceling {} ({}).", owner.getTaskNameWithSubtasks(), owner.getExecutionId());
// Request user code to shut down
if (invokable != null) {
try {
invokable.cancel();
}
catch (Throwable e) {
LOG.error("Error while canceling the task.", e);
}
}
final Thread executingThread = this.executingThread;
if (executingThread != null) {
// interrupt the running thread and wait for it to die
executingThread.interrupt();
try {
executingThread.join(5000);
}
catch (InterruptedException e) {
}
if (!executingThread.isAlive()) {
return;
}
// Continuously interrupt the user thread until it changed to state CANCELED
while (executingThread != null && executingThread.isAlive()) {
LOG.warn("Task " + owner.getTaskNameWithSubtasks() + " did not react to cancelling signal. Sending repeated interrupt.");
if (LOG.isDebugEnabled()) {
StringBuilder bld = new StringBuilder("Task ").append(owner.getTaskNameWithSubtasks()).append(" is stuck in method:\n");
StackTraceElement[] stack = executingThread.getStackTrace();
for (StackTraceElement e : stack) {
bld.append(e).append('\n');
}
LOG.debug(bld.toString());
}
executingThread.interrupt();
try {
executingThread.join(1000);
}
catch (InterruptedException e) {
}
}
}
}
@Override
public ActorRef getJobManager() {
return jobManager;
}
@Override
public IOManager getIOManager() {
return ioManager;
}
@Override
public MemoryManager getMemoryManager() {
return memoryManager;
}
@Override
public BroadcastVariableManager getBroadcastVariableManager() {
return broadcastVariableManager;
}
@Override
public void reportAccumulators(Map<String, Accumulator<?, ?>> accumulators) {
AccumulatorEvent evt;
try {
evt = new AccumulatorEvent(getJobID(), accumulators);
}
catch (IOException e) {
throw new RuntimeException("Cannot serialize accumulators to send them to JobManager", e);
}
ReportAccumulatorResult accResult = new ReportAccumulatorResult(getJobID(), owner.getExecutionId(), evt);
jobManager.tell(accResult, ActorRef.noSender());
}
@Override
public ResultPartitionWriter getWriter(int index) {
checkElementIndex(index, writers.length, "Illegal environment writer request.");
return writers[checkElementIndex(index, writers.length)];
}
@Override
public ResultPartitionWriter[] getAllWriters() {
return writers;
}
@Override
public InputGate getInputGate(int index) {
checkElementIndex(index, inputGates.length);
return inputGates[index];
}
@Override
public SingleInputGate[] getAllInputGates() {
return inputGates;
}
public ResultPartition[] getProducedPartitions() {
return producedPartitions;
}
public SingleInputGate getInputGateById(IntermediateDataSetID id) {
return inputGatesById.get(id);
}
@Override
public Configuration getTaskConfiguration() {
return taskConfiguration;
}
@Override
public Configuration getJobConfiguration() {
return jobConfiguration;
}
@Override
public int getNumberOfSubtasks() {
return owner.getNumberOfSubtasks();
}
@Override
public int getIndexInSubtaskGroup() {
return owner.getSubtaskIndex();
}
@Override
public String getTaskName() {
return owner.getTaskName();
}
@Override
public InputSplitProvider getInputSplitProvider() {
return inputSplitProvider;
}
@Override
public String getTaskNameWithSubtasks() {
return owner.getTaskNameWithSubtasks();
}
@Override
public ClassLoader getUserClassLoader() {
return userCodeClassLoader;
}
public void addCopyTasksForCacheFile(Map<String, FutureTask<Path>> copyTasks) {
cacheCopyTasks.putAll(copyTasks);
}
public void addCopyTaskForCacheFile(String name, FutureTask<Path> copyTask) {
cacheCopyTasks.put(name, copyTask);
}
@Override
public Map<String, FutureTask<Path>> getCopyTask() {
return cacheCopyTasks;
}
}
......@@ -243,7 +243,7 @@ public class NetworkEnvironment {
public void registerTask(Task task) throws IOException {
final ResultPartition[] producedPartitions = task.getProducedPartitions();
final ResultPartitionWriter[] writers = task.getWriters();
final ResultPartitionWriter[] writers = task.getAllWriters();
if (writers.length != producedPartitions.length) {
throw new IllegalStateException("Unequal number of writers and partitions.");
......@@ -288,7 +288,7 @@ public class NetworkEnvironment {
}
// Setup the buffer pool for each buffer reader
final SingleInputGate[] inputGates = task.getInputGates();
final SingleInputGate[] inputGates = task.getAllInputGates();
for (SingleInputGate gate : inputGates) {
BufferPool bufferPool = null;
......@@ -329,10 +329,9 @@ public class NetworkEnvironment {
partitionManager.releasePartitionsProducedBy(executionId);
}
ResultPartitionWriter[] writers = task.getWriters();
ResultPartitionWriter[] writers = task.getAllWriters();
if (writers != null) {
for (ResultPartitionWriter writer : task.getWriters()) {
for (ResultPartitionWriter writer : writers) {
taskEventDispatcher.unregisterWriter(writer);
}
}
......@@ -344,7 +343,7 @@ public class NetworkEnvironment {
}
}
final SingleInputGate[] inputGates = task.getInputGates();
final SingleInputGate[] inputGates = task.getAllInputGates();
if (inputGates != null) {
for (SingleInputGate gate : inputGates) {
......
......@@ -18,7 +18,6 @@
package org.apache.flink.runtime.io.network.partition;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.disk.iomanager.IOManager.IOMode;
......@@ -76,9 +75,8 @@ import static com.google.common.base.Preconditions.checkState;
public class ResultPartition implements BufferPoolOwner {
private static final Logger LOG = LoggerFactory.getLogger(ResultPartition.class);
/** The owning environment. Mainly for debug purposes. */
private final Environment owner;
private final String owningTaskName;
private final JobID jobId;
......@@ -120,7 +118,7 @@ public class ResultPartition implements BufferPoolOwner {
private long totalNumberOfBytes;
public ResultPartition(
Environment owner,
String owningTaskName,
JobID jobId,
ResultPartitionID partitionId,
ResultPartitionType partitionType,
......@@ -130,7 +128,7 @@ public class ResultPartition implements BufferPoolOwner {
IOManager ioManager,
IOMode defaultIoMode) {
this.owner = checkNotNull(owner);
this.owningTaskName = checkNotNull(owningTaskName);
this.jobId = checkNotNull(jobId);
this.partitionId = checkNotNull(partitionId);
this.partitionType = checkNotNull(partitionType);
......@@ -162,7 +160,7 @@ public class ResultPartition implements BufferPoolOwner {
// Initially, partitions should be consumed once before release.
pin();
LOG.debug("{}: Initialized {}", owner.getTaskNameWithSubtasks(), this);
LOG.debug("{}: Initialized {}", owningTaskName, this);
}
/**
......@@ -281,7 +279,7 @@ public class ResultPartition implements BufferPoolOwner {
*/
public void release() {
if (isReleased.compareAndSet(false, true)) {
LOG.debug("{}: Releasing {}.", owner.getTaskNameWithSubtasks(), this);
LOG.debug("{}: Releasing {}.", owningTaskName, this);
// Release all subpartitions
for (ResultSubpartition subpartition : subpartitions) {
......
......@@ -24,7 +24,6 @@ import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionLocation;
import org.apache.flink.runtime.event.task.AbstractEvent;
import org.apache.flink.runtime.event.task.TaskEvent;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.io.network.NetworkEnvironment;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
......@@ -101,8 +100,8 @@ public class SingleInputGate implements InputGate {
/** Lock object to guard partition requests and runtime channel updates. */
private final Object requestLock = new Object();
/** The owning environment. Mainly for debug purposes. */
private final Environment owner;
/** The name of the owning task, for logging purposes. */
private final String owningTaskName;
/**
* The ID of the consumed intermediate result. Each input gate consumes partitions of the
......@@ -153,12 +152,12 @@ public class SingleInputGate implements InputGate {
private int numberOfUninitializedChannels;
public SingleInputGate(
Environment owner,
String owningTaskName,
IntermediateDataSetID consumedResultId,
int consumedSubpartitionIndex,
int numberOfInputChannels) {
this.owner = checkNotNull(owner);
this.owningTaskName = checkNotNull(owningTaskName);
this.consumedResultId = checkNotNull(consumedResultId);
checkArgument(consumedSubpartitionIndex >= 0);
......@@ -265,7 +264,7 @@ public class SingleInputGate implements InputGate {
synchronized (requestLock) {
if (!isReleased) {
try {
LOG.debug("{}: Releasing {}.", owner.getTaskNameWithSubtasks(), this);
LOG.debug("{}: Releasing {}.", owningTaskName, this);
for (InputChannel inputChannel : inputChannels.values()) {
try {
......@@ -410,7 +409,7 @@ public class SingleInputGate implements InputGate {
* Creates an input gate and all of its input channels.
*/
public static SingleInputGate create(
Environment owner,
String owningTaskName,
InputGateDeploymentDescriptor igdd,
NetworkEnvironment networkEnvironment) {
......@@ -422,7 +421,7 @@ public class SingleInputGate implements InputGate {
final InputChannelDeploymentDescriptor[] icdd = checkNotNull(igdd.getInputChannelDeploymentDescriptors());
final SingleInputGate inputGate = new SingleInputGate(
owner, consumedResultId, consumedSubpartitionIndex, icdd.length);
owningTaskName, consumedResultId, consumedSubpartitionIndex, icdd.length);
// Create the input channels. There is one input channel for each consumed partition.
final InputChannel[] inputChannels = new InputChannel[icdd.length];
......
......@@ -1067,7 +1067,8 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
public DistributedRuntimeUDFContext createRuntimeContext(String taskName) {
Environment env = getEnvironment();
return new DistributedRuntimeUDFContext(taskName, env.getNumberOfSubtasks(),
env.getIndexInSubtaskGroup(), getUserCodeClassLoader(), getExecutionConfig(), env.getCopyTask());
env.getIndexInSubtaskGroup(), getUserCodeClassLoader(), getExecutionConfig(),
env.getDistributedCacheEntries());
}
// --------------------------------------------------------------------------------------------
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.taskmanager;
import akka.actor.ActorRef;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.accumulators.AccumulatorEvent;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
import org.apache.flink.runtime.memorymanager.MemoryManager;
import org.apache.flink.runtime.messages.accumulators.ReportAccumulatorResult;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.Future;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkArgument;
/**
* In implementation of the {@link Environment}.
*/
public class RuntimeEnvironment implements Environment {
private final JobID jobId;
private final JobVertexID jobVertexId;
private final ExecutionAttemptID executionId;
private final String taskName;
private final String taskNameWithSubtasks;
private final int subtaskIndex;
private final int parallelism;
private final Configuration jobConfiguration;
private final Configuration taskConfiguration;
private final ClassLoader userCodeClassLoader;
private final MemoryManager memManager;
private final IOManager ioManager;
private final BroadcastVariableManager bcVarManager;
private final InputSplitProvider splitProvider;
private final Map<String, Future<Path>> distCacheEntries;
private final ResultPartitionWriter[] writers;
private final InputGate[] inputGates;
private final ActorRef jobManagerActor;
// ------------------------------------------------------------------------
public RuntimeEnvironment(JobID jobId, JobVertexID jobVertexId, ExecutionAttemptID executionId,
String taskName, String taskNameWithSubtasks,
int subtaskIndex, int parallelism,
Configuration jobConfiguration, Configuration taskConfiguration,
ClassLoader userCodeClassLoader,
MemoryManager memManager, IOManager ioManager,
BroadcastVariableManager bcVarManager,
InputSplitProvider splitProvider,
Map<String, Future<Path>> distCacheEntries,
ResultPartitionWriter[] writers,
InputGate[] inputGates,
ActorRef jobManagerActor) {
checkArgument(parallelism > 0 && subtaskIndex >= 0 && subtaskIndex < parallelism);
this.jobId = checkNotNull(jobId);
this.jobVertexId = checkNotNull(jobVertexId);
this.executionId = checkNotNull(executionId);
this.taskName = checkNotNull(taskName);
this.taskNameWithSubtasks = checkNotNull(taskNameWithSubtasks);
this.subtaskIndex = subtaskIndex;
this.parallelism = parallelism;
this.jobConfiguration = checkNotNull(jobConfiguration);
this.taskConfiguration = checkNotNull(taskConfiguration);
this.userCodeClassLoader = checkNotNull(userCodeClassLoader);
this.memManager = checkNotNull(memManager);
this.ioManager = checkNotNull(ioManager);
this.bcVarManager = checkNotNull(bcVarManager);
this.splitProvider = checkNotNull(splitProvider);
this.distCacheEntries = checkNotNull(distCacheEntries);
this.writers = checkNotNull(writers);
this.inputGates = checkNotNull(inputGates);
this.jobManagerActor = checkNotNull(jobManagerActor);
}
// ------------------------------------------------------------------------
@Override
public JobID getJobID() {
return jobId;
}
@Override
public JobVertexID getJobVertexId() {
return jobVertexId;
}
@Override
public ExecutionAttemptID getExecutionId() {
return executionId;
}
@Override
public String getTaskName() {
return taskName;
}
@Override
public String getTaskNameWithSubtasks() {
return taskNameWithSubtasks;
}
@Override
public int getNumberOfSubtasks() {
return parallelism;
}
@Override
public int getIndexInSubtaskGroup() {
return subtaskIndex;
}
@Override
public Configuration getJobConfiguration() {
return jobConfiguration;
}
@Override
public Configuration getTaskConfiguration() {
return taskConfiguration;
}
@Override
public ClassLoader getUserClassLoader() {
return userCodeClassLoader;
}
@Override
public MemoryManager getMemoryManager() {
return memManager;
}
@Override
public IOManager getIOManager() {
return ioManager;
}
@Override
public BroadcastVariableManager getBroadcastVariableManager() {
return bcVarManager;
}
@Override
public InputSplitProvider getInputSplitProvider() {
return splitProvider;
}
@Override
public Map<String, Future<Path>> getDistributedCacheEntries() {
return distCacheEntries;
}
@Override
public ResultPartitionWriter getWriter(int index) {
return writers[index];
}
@Override
public ResultPartitionWriter[] getAllWriters() {
return writers;
}
@Override
public InputGate getInputGate(int index) {
return inputGates[index];
}
@Override
public InputGate[] getAllInputGates() {
return inputGates;
}
@Override
public void reportAccumulators(Map<String, Accumulator<?, ?>> accumulators) {
AccumulatorEvent evt;
try {
evt = new AccumulatorEvent(getJobID(), accumulators);
}
catch (IOException e) {
throw new RuntimeException("Cannot serialize accumulators to send them to JobManager", e);
}
ReportAccumulatorResult accResult = new ReportAccumulatorResult(jobId, executionId, evt);
jobManagerActor.tell(accResult, ActorRef.noSender());
}
@Override
public ActorRef getJobManager() {
return jobManagerActor;
}
}
......@@ -19,382 +19,755 @@
package org.apache.flink.runtime.taskmanager;
import akka.actor.ActorRef;
import akka.util.Timeout;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.blob.BlobKey;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
import org.apache.flink.runtime.execution.CancelTaskException;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.execution.RuntimeEnvironment;
import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.filecache.FileCache;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.NetworkEnvironment;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.jobgraph.tasks.BarrierTransceiver;
import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier;
import org.apache.flink.runtime.memorymanager.MemoryManager;
import org.apache.flink.runtime.messages.ExecutionGraphMessages;
import org.apache.flink.runtime.messages.TaskMessages.UnregisterTask;
import org.apache.flink.runtime.profiling.TaskManagerProfiler;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.runtime.messages.TaskMessages;
import org.apache.flink.runtime.messages.TaskMessages.TaskInFinalState;
import org.apache.flink.runtime.messages.TaskManagerMessages.FatalError;
import org.apache.flink.runtime.state.StateHandle;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.duration.FiniteDuration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
public class Task {
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
/**
* The Task represents one execution of a parallel subtask on a TaskManager.
* A Task wraps a Flink operator (which may be a user function) and
* runs it, providing all service necessary for example to consume input data,
* produce its results (intermediate result partitions) and communicate
* with the JobManager.
*
* <p>The Flink operators (implemented as subclasses of
* {@link org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable} have only data
* readers, -writers, and certain event callbacks. The task connects those to the
* network stack and actor messages, and tracks the state of the execution and
* handles exceptions.</p>
*
* <p>Tasks have no knowledge about how they relate to other tasks, or whether they
* are the first attempt to execute the task, or a repeated attempt. All of that
* is only known to the JobManager. All the task knows are its own runnable code,
* the task's configuration, and the IDs of the intermediate results to consume and
* produce (if any).</p>
*
* <p>Each Task is run by one dedicated thread.</p>
*/
public class Task implements Runnable {
/** The class logger. */
private static final Logger LOG = LoggerFactory.getLogger(Task.class);
/** The tread group that contains all task threads */
private static final ThreadGroup TASK_THREADS_GROUP = new ThreadGroup("Flink Task Threads");
/** For atomic state updates */
private static final AtomicReferenceFieldUpdater<Task, ExecutionState> STATE_UPDATER =
AtomicReferenceFieldUpdater.newUpdater(Task.class, ExecutionState.class, "executionState");
/** The log object used for debugging. */
private static final Logger LOG = LoggerFactory.getLogger(Task.class);
// --------------------------------------------------------------------------------------------
// ------------------------------------------------------------------------
// Constant fields that are part of the initial Task construction
// ------------------------------------------------------------------------
/** The job that the task belongs to */
private final JobID jobId;
/** The vertex in the JobGraph whose code the task executes */
private final JobVertexID vertexId;
/** The execution attempt of the parallel subtask */
private final ExecutionAttemptID executionId;
/** The index of the parallel subtask, in [0, numberOfSubtasks) */
private final int subtaskIndex;
private final int numberOfSubtasks;
private final ExecutionAttemptID executionId;
/** The number of parallel subtasks for the JobVertex/ExecutionJobVertex that this task belongs to */
private final int parallelism;
/** The name of the task */
private final String taskName;
/** The name of the task, including the subtask index and the parallelism */
private final String taskNameWithSubtask;
/** The job-wide configuration object */
private final Configuration jobConfiguration;
/** The task-specific configuration */
private final Configuration taskConfiguration;
/** The jar files used by this task */
private final List<BlobKey> requiredJarFiles;
/** The name of the class that holds the invokable code */
private final String nameOfInvokableClass;
/** The handle to the state that the operator was initialized with */
private final StateHandle operatorState;
/** The memory manager to be used by this task */
private final MemoryManager memoryManager;
/** The I/O manager to be used by this task */
private final IOManager ioManager;
/** The BroadcastVariableManager to be used by this task */
private final BroadcastVariableManager broadcastVariableManager;
private final ResultPartition[] producedPartitions;
private final ResultPartitionWriter[] writers;
private final SingleInputGate[] inputGates;
private final Map<IntermediateDataSetID, SingleInputGate> inputGatesById;
/** The TaskManager actor that spawned this task */
private final ActorRef taskManager;
private final List<ActorRef> executionListenerActors = new CopyOnWriteArrayList<ActorRef>();
/** The JobManager actor */
private final ActorRef jobManager;
/** All actors that want to be notified about changes in the task's execution state */
private final List<ActorRef> executionListenerActors;
/** The timeout for all ask operations on actors */
private final Timeout actorAskTimeout;
private final LibraryCacheManager libraryCache;
private final FileCache fileCache;
private final NetworkEnvironment network;
/** The environment (with the invokable) executed by this task */
private volatile RuntimeEnvironment environment;
/** The thread that executes the task */
private final Thread executingThread;
// ------------------------------------------------------------------------
// Fields that control the task execution
// ------------------------------------------------------------------------
private final AtomicBoolean invokableHasBeenCanceled = new AtomicBoolean(false);
/** The invokable of this task, if initialized */
private volatile AbstractInvokable invokable;
/** The current execution state of the task */
private volatile ExecutionState executionState = ExecutionState.DEPLOYING;
private volatile ExecutionState executionState = ExecutionState.CREATED;
/** The observed exception, in case the task execution failed */
private volatile Throwable failureCause;
// --------------------------------------------------------------------------------------------
/**
* <p><b>IMPORTANT:</b> This constructor may not start any work that would need to
* be undone in the case of a failing task deployment.</p>
*/
public Task(TaskDeploymentDescriptor tdd,
MemoryManager memManager,
IOManager ioManager,
NetworkEnvironment networkEnvironment,
BroadcastVariableManager bcVarManager,
ActorRef taskManagerActor,
ActorRef jobManagerActor,
FiniteDuration actorAskTimeout,
LibraryCacheManager libraryCache,
FileCache fileCache)
{
checkArgument(tdd.getNumberOfSubtasks() > 0);
checkArgument(tdd.getIndexInSubtaskGroup() >= 0);
checkArgument(tdd.getIndexInSubtaskGroup() < tdd.getNumberOfSubtasks());
this.jobId = checkNotNull(tdd.getJobID());
this.vertexId = checkNotNull(tdd.getVertexID());
this.executionId = checkNotNull(tdd.getExecutionId());
this.subtaskIndex = tdd.getIndexInSubtaskGroup();
this.parallelism = tdd.getNumberOfSubtasks();
this.taskName = checkNotNull(tdd.getTaskName());
this.taskNameWithSubtask = getTaskNameWithSubtask(taskName, subtaskIndex, parallelism);
this.jobConfiguration = checkNotNull(tdd.getJobConfiguration());
this.taskConfiguration = checkNotNull(tdd.getTaskConfiguration());
this.requiredJarFiles = checkNotNull(tdd.getRequiredJarFiles());
this.nameOfInvokableClass = checkNotNull(tdd.getInvokableClassName());
this.operatorState = tdd.getOperatorStates();
this.memoryManager = checkNotNull(memManager);
this.ioManager = checkNotNull(ioManager);
this.broadcastVariableManager =checkNotNull(bcVarManager);
this.jobManager = checkNotNull(jobManagerActor);
this.taskManager = checkNotNull(taskManagerActor);
this.actorAskTimeout = new Timeout(checkNotNull(actorAskTimeout));
this.libraryCache = checkNotNull(libraryCache);
this.fileCache = checkNotNull(fileCache);
this.network = checkNotNull(networkEnvironment);
this.executionListenerActors = new CopyOnWriteArrayList<ActorRef>();
// create the reader and writer structures
final String taskNameWithSubtasksAndId =
Task.getTaskNameWithSubtaskAndID(taskName, subtaskIndex, parallelism, executionId);
List<ResultPartitionDeploymentDescriptor> partitions = tdd.getProducedPartitions();
List<InputGateDeploymentDescriptor> consumedPartitions = tdd.getInputGates();
// Produced intermediate result partitions
this.producedPartitions = new ResultPartition[partitions.size()];
this.writers = new ResultPartitionWriter[partitions.size()];
for (int i = 0; i < this.producedPartitions.length; i++) {
ResultPartitionDeploymentDescriptor desc = partitions.get(i);
ResultPartitionID partitionId = new ResultPartitionID(desc.getPartitionId(), executionId);
this.producedPartitions[i] = new ResultPartition(
taskNameWithSubtasksAndId,
jobId,
partitionId,
desc.getPartitionType(),
desc.getNumberOfSubpartitions(),
networkEnvironment.getPartitionManager(),
networkEnvironment.getPartitionConsumableNotifier(),
ioManager,
networkEnvironment.getDefaultIOMode());
this.writers[i] = new ResultPartitionWriter(this.producedPartitions[i]);
}
// Consumed intermediate result partitions
this.inputGates = new SingleInputGate[consumedPartitions.size()];
this.inputGatesById = new HashMap<IntermediateDataSetID, SingleInputGate>();
public Task(JobID jobId, JobVertexID vertexId, int taskIndex, int parallelism,
ExecutionAttemptID executionId, String taskName, ActorRef taskManager) {
for (int i = 0; i < this.inputGates.length; i++) {
SingleInputGate gate = SingleInputGate.create(
taskNameWithSubtasksAndId, consumedPartitions.get(i), networkEnvironment);
this.jobId = jobId;
this.vertexId = vertexId;
this.subtaskIndex = taskIndex;
this.numberOfSubtasks = parallelism;
this.executionId = executionId;
this.taskName = taskName;
this.taskManager = taskManager;
this.inputGates[i] = gate;
inputGatesById.put(gate.getConsumedResultId(), gate);
}
// finally, create the executing thread, but do not start it
executingThread = new Thread(TASK_THREADS_GROUP, this, taskNameWithSubtask);
}
/**
* Returns the ID of the job this task belongs to.
*/
// ------------------------------------------------------------------------
// Accessors
// ------------------------------------------------------------------------
public JobID getJobID() {
return this.jobId;
return jobId;
}
/**
* Returns the ID of this task vertex.
*/
public JobVertexID getVertexID() {
return this.vertexId;
public JobVertexID getJobVertexId() {
return vertexId;
}
/**
* Gets the index of the parallel subtask [0, parallelism).
*/
public int getSubtaskIndex() {
public ExecutionAttemptID getExecutionId() {
return executionId;
}
public int getIndexInSubtaskGroup() {
return subtaskIndex;
}
/**
* Gets the total number of subtasks of the task that this subtask belongs to.
*/
public int getNumberOfSubtasks() {
return numberOfSubtasks;
return parallelism;
}
/**
* Gets the ID of the execution attempt.
*/
public ExecutionAttemptID getExecutionId() {
return executionId;
public String getTaskName() {
return taskName;
}
public String getTaskNameWithSubtasks() {
return taskNameWithSubtask;
}
public Configuration getJobConfiguration() {
return jobConfiguration;
}
public Configuration getTaskConfiguration() {
return this.taskConfiguration;
}
public ResultPartitionWriter[] getAllWriters() {
return writers;
}
public SingleInputGate[] getAllInputGates() {
return inputGates;
}
public ResultPartition[] getProducedPartitions() {
return producedPartitions;
}
public SingleInputGate getInputGateById(IntermediateDataSetID id) {
return inputGatesById.get(id);
}
public Thread getExecutingThread() {
return executingThread;
}
// ------------------------------------------------------------------------
// Task Execution
// ------------------------------------------------------------------------
/**
* Returns the current execution state of the task.
* @return The current execution state of the task.
*/
public ExecutionState getExecutionState() {
return this.executionState;
}
public void setEnvironment(RuntimeEnvironment environment) {
this.environment = environment;
}
public RuntimeEnvironment getEnvironment() {
return environment;
}
/**
* Checks whether the task has failed, is canceled, or is being canceled at the moment.
* @return True is the task in state FAILED, CANCELING, or CANCELED, false otherwise.
*/
public boolean isCanceledOrFailed() {
return executionState == ExecutionState.CANCELING ||
executionState == ExecutionState.CANCELED ||
executionState == ExecutionState.FAILED;
}
public String getTaskName() {
if (LOG.isDebugEnabled()) {
return taskName + " (" + executionId + ")";
} else {
return taskName;
}
}
public String getTaskNameWithSubtasks() {
if (LOG.isDebugEnabled()) {
return this.taskName + " (" + (this.subtaskIndex + 1) + "/" + this.numberOfSubtasks +
") (" + executionId + ")";
} else {
return this.taskName + " (" + (this.subtaskIndex + 1) + "/" + this.numberOfSubtasks + ")";
}
}
/**
* If the task has failed, this method gets the exception that caused this task to fail.
* Otherwise this method returns null.
*
* @return The exception that caused the task to fail, or null, if the task has not failed.
*/
public Throwable getFailureCause() {
return failureCause;
}
// ----------------------------------------------------------------------------------------------------------------
// States and Transitions
// ----------------------------------------------------------------------------------------------------------------
/**
* Marks the task as finished. This succeeds, if the task was previously in the state
* "RUNNING", otherwise it fails. Failure indicates that the task was either
* canceled, or set to failed.
*
* @return True, if the task correctly enters the state FINISHED.
* Starts the task's thread.
*/
public boolean markAsFinished() {
if (STATE_UPDATER.compareAndSet(this, ExecutionState.RUNNING, ExecutionState.FINISHED)) {
notifyObservers(ExecutionState.FINISHED, null);
unregisterTask();
return true;
}
else {
return false;
}
public void startTaskThread() {
executingThread.start();
}
public void markFailed(Throwable error) {
/**
* The core work method that bootstraps the task and executes it code
*/
public void run() {
// ----------------------------
// Initial State transition
// ----------------------------
while (true) {
ExecutionState current = this.executionState;
// if canceled, fine. we are done, and the jobmanager has been told
if (current == ExecutionState.CANCELED) {
if (current == ExecutionState.CREATED) {
if (STATE_UPDATER.compareAndSet(this, ExecutionState.CREATED, ExecutionState.DEPLOYING)) {
// success, we can start our work
break;
}
}
else if (current == ExecutionState.FAILED) {
// we were immediately failed. tell the TaskManager that we reached our final state
notifyFinalState();
return;
}
else if (current == ExecutionState.CANCELING) {
if (STATE_UPDATER.compareAndSet(this, ExecutionState.CANCELING, ExecutionState.CANCELED)) {
// we were immediately canceled. tell the TaskManager that we reached our final state
notifyFinalState();
return;
}
}
else {
throw new IllegalStateException("Invalid state for beginning of task operation");
}
}
// if canceling, we are done, but we cannot be sure that the jobmanager has been told.
// after all, we may have recognized our failure state before the cancelling and never sent a canceled
// message back
else if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) {
this.failureCause = error;
// all resource acquisitions and registrations from here on
// need to be undone in the end
notifyObservers(ExecutionState.FAILED, ExceptionUtils.stringifyException(error));
unregisterTask();
Map<String, Future<Path>> distributedCacheEntries = new HashMap<String, Future<Path>>();
return;
}
}
}
AbstractInvokable invokable = null;
public void cancelExecution() {
while (true) {
ExecutionState current = this.executionState;
try {
// ----------------------------
// Task Bootstrap - We periodically
// check for canceling as a shortcut
// ----------------------------
// if the task is already canceled (or canceling) or finished or failed,
// then we need not do anything
if (current == ExecutionState.FINISHED || current == ExecutionState.CANCELED ||
current == ExecutionState.CANCELING || current == ExecutionState.FAILED) {
return;
}
// first of all, get a user-code classloader
// this may involve downloading the job's JAR files and/or classes
LOG.info("Loading JAR files for task " + taskNameWithSubtask);
final ClassLoader userCodeClassLoader = createUserCodeClassloader(libraryCache);
if (current == ExecutionState.DEPLOYING) {
// directly set to canceled
if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.CANCELED)) {
// now load the task's invokable code
invokable = loadAndInstantiateInvokable(userCodeClassLoader, nameOfInvokableClass);
notifyObservers(ExecutionState.CANCELED, null);
unregisterTask();
return;
}
if (isCanceledOrFailed()) {
throw new CancelTaskException();
}
else if (current == ExecutionState.RUNNING) {
// go to canceling and perform the actual task canceling
if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.CANCELING)) {
notifyObservers(ExecutionState.CANCELING, null);
try {
this.environment.cancelExecution();
}
catch (Throwable e) {
LOG.error("Error while cancelling the task.", e);
}
return;
// ----------------------------------------------------------------
// register the task with the network stack
// this operation may fail if the system does not have enough
// memory to run the necessary data exchanges
// the registration must also strictly be undone
// ----------------------------------------------------------------
LOG.info("Registering task at network: " + this);
network.registerTask(this);
// next, kick off the background copying of files for the distributed cache
try {
for (Map.Entry<String, DistributedCache.DistributedCacheEntry> entry :
DistributedCache.readFileInfoFromConfig(jobConfiguration))
{
LOG.info("Obtaining local cache file for '" + entry.getKey() + '\'');
Future<Path> cp = fileCache.createTmpFile(entry.getKey(), entry.getValue(), jobId);
distributedCacheEntries.put(entry.getKey(), cp);
}
}
else {
throw new RuntimeException("unexpected state for cancelling: " + current);
catch (Exception e) {
throw new Exception("Exception while adding files to distributed cache.", e);
}
}
}
/**
* Sets the tasks to be cancelled and reports a failure back to the master.
*/
public void failExternally(Throwable cause) {
while (true) {
ExecutionState current = this.executionState;
// if the task is already canceled (or canceling) or finished or failed,
// then we need not do anything
if (current == ExecutionState.CANCELED || current == ExecutionState.CANCELING || current == ExecutionState.FAILED) {
return;
if (isCanceledOrFailed()) {
throw new CancelTaskException();
}
if (current == ExecutionState.FINISHED) {
// Set state to failed in order to correctly unregister task from network environment
if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) {
notifyObservers(ExecutionState.FAILED, null);
// ----------------------------------------------------------------
// call the user code initialization methods
// ----------------------------------------------------------------
return;
TaskInputSplitProvider splitProvider = new TaskInputSplitProvider(jobManager,
jobId, vertexId, executionId, userCodeClassLoader, actorAskTimeout);
Environment env = new RuntimeEnvironment(jobId, vertexId, executionId,
taskName, taskNameWithSubtask, subtaskIndex, parallelism,
jobConfiguration, taskConfiguration,
userCodeClassLoader, memoryManager, ioManager, broadcastVariableManager,
splitProvider, distributedCacheEntries,
writers, inputGates, jobManager);
// let the task code create its readers and writers
invokable.setEnvironment(env);
try {
invokable.registerInputOutput();
}
catch (Exception e) {
throw new Exception("Call to registerInputOutput() of invokable failed", e);
}
// the very last thing before the actual execution starts running is to inject
// the state into the task. the state is non-empty if this is an execution
// of a task that failed but had backuped state from a checkpoint
if (operatorState != null) {
if (invokable instanceof OperatorStateCarrier) {
((OperatorStateCarrier) invokable).injectState(operatorState);
}
else {
throw new IllegalStateException("Found operator state for a non-stateful task invokable");
}
}
if (current == ExecutionState.DEPLOYING) {
// directly set to canceled
if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) {
this.failureCause = cause;
// ----------------------------------------------------------------
// actual task core work
// ----------------------------------------------------------------
notifyObservers(ExecutionState.FAILED, null);
unregisterTask();
return;
// we must make strictly sure that the invokable is accessible to teh cancel() call
// by the time we switched to running.
this.invokable = invokable;
// switch to the RUNNING state, if that fails, we have been canceled/failed in the meantime
if (!STATE_UPDATER.compareAndSet(this, ExecutionState.DEPLOYING, ExecutionState.RUNNING)) {
throw new CancelTaskException();
}
// notify everyone that we switched to running. especially the TaskManager needs
// to know this!
notifyObservers(ExecutionState.RUNNING, null);
taskManager.tell(new TaskMessages.UpdateTaskExecutionState(
new TaskExecutionState(jobId, executionId, ExecutionState.RUNNING)), ActorRef.noSender());
// make sure the user code classloader is accessible thread-locally
executingThread.setContextClassLoader(userCodeClassLoader);
// run the invokable
invokable.invoke();
// make sure, we enter the catch block if the task leaves the invoke() method due
// to the fact that it has been canceled
if (isCanceledOrFailed()) {
throw new CancelTaskException();
}
// ----------------------------------------------------------------
// finalization of a successful execution
// ----------------------------------------------------------------
// finish the produced partitions. if this fails, we consider the execution failed.
for (ResultPartition partition : producedPartitions) {
if (partition != null) {
partition.finish();
}
}
else if (current == ExecutionState.RUNNING) {
// go to canceling and perform the actual task canceling
if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) {
try {
this.environment.cancelExecution();
// try to mark the task as finished
// if that fails, the task was canceled/failed in the meantime
if (STATE_UPDATER.compareAndSet(this, ExecutionState.RUNNING, ExecutionState.FINISHED)) {
notifyObservers(ExecutionState.FINISHED, null);
}
else {
throw new CancelTaskException();
}
}
catch (Throwable t) {
// ----------------------------------------------------------------
// the execution failed. either the invokable code properly failed, or
// an exception was thrown as a side effect of cancelling
// ----------------------------------------------------------------
try {
// transition into our final state. we should be either in RUNNING, CANCELING, or FAILED
// loop for multiple retries during concurrent state changes via calls to cancel() or
// to failExternally()
while (true) {
ExecutionState current = this.executionState;
if (current == ExecutionState.RUNNING || current == ExecutionState.DEPLOYING) {
if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) {
// proper failure of the task. record the exception as the root cause
failureCause = t;
notifyObservers(ExecutionState.FAILED, t);
// in case of an exception during execution, we still call "cancel()" on the task
if (invokable != null && this.invokable != null && invokableHasBeenCanceled.compareAndSet(false, true)) {
try {
invokable.cancel();
}
catch (Throwable t2) {
LOG.error("Error while canceling task " + taskNameWithSubtask, t2);
}
}
break;
}
}
catch (Throwable e) {
LOG.error("Error while cancelling the task.", e);
else if (current == ExecutionState.CANCELING) {
if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.CANCELED)) {
notifyObservers(ExecutionState.CANCELED, null);
break;
}
}
else if (current == ExecutionState.FAILED) {
// in state failed already, no transition necessary any more
break;
}
// unexpected state, go to failed
else if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) {
LOG.error("Unexpected state in Task during an exception: " + current);
break;
}
// else fall through the loop and
}
}
catch (Throwable tt) {
String message = "FATAL - exception in task exception handler";
LOG.error(message, tt);
notifyFatalError(message, tt);
}
}
finally {
try {
LOG.info("Freeing task resources for " + taskNameWithSubtask);
// free the network resources
network.unregisterTask(this);
if (invokable != null) {
memoryManager.releaseAll(invokable);
}
this.failureCause = cause;
// remove all of the tasks library resources
libraryCache.unregisterTask(jobId, executionId);
notifyObservers(ExecutionState.FAILED, null);
unregisterTask();
// remove all files in the distributed cache
removeCachedFiles(distributedCacheEntries, fileCache);
return;
}
notifyFinalState();
}
else {
throw new RuntimeException("unexpected state for failing the task: " + current);
catch (Throwable t) {
// an error in the resource cleanup is fatal
String message = "FATAL - exception in task resource cleanup";
LOG.error(message, t);
notifyFatalError(message, t);
}
}
}
public void cancelingDone() {
while (true) {
ExecutionState current = this.executionState;
private ClassLoader createUserCodeClassloader(LibraryCacheManager libraryCache) throws Exception {
long startDownloadTime = System.currentTimeMillis();
if (current == ExecutionState.CANCELED || current == ExecutionState.FAILED) {
return;
}
if (!(current == ExecutionState.RUNNING || current == ExecutionState.CANCELING)) {
LOG.error(String.format("Unexpected state transition in Task: %s -> %s", current, ExecutionState.CANCELED));
}
// triggers the download of all missing jar files from the job manager
libraryCache.registerTask(jobId, executionId, requiredJarFiles);
if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.CANCELED)) {
notifyObservers(ExecutionState.CANCELED, null);
unregisterTask();
return;
}
LOG.debug("Register task {} at library cache manager took {} milliseconds",
executionId, System.currentTimeMillis() - startDownloadTime);
ClassLoader userCodeClassLoader = libraryCache.getClassLoader(jobId);
if (userCodeClassLoader == null) {
throw new Exception("No user code classloader available.");
}
return userCodeClassLoader;
}
/**
* Starts the execution of this task.
*/
public boolean startExecution() {
LOG.info("Starting execution of task {}", this.getTaskName());
if (STATE_UPDATER.compareAndSet(this, ExecutionState.DEPLOYING, ExecutionState.RUNNING)) {
final Thread thread = this.environment.getExecutingThread();
thread.start();
return true;
private AbstractInvokable loadAndInstantiateInvokable(ClassLoader classLoader, String className) throws Exception {
Class<? extends AbstractInvokable> invokableClass;
try {
invokableClass = Class.forName(className, true, classLoader)
.asSubclass(AbstractInvokable.class);
}
else {
return false;
catch (Throwable t) {
throw new Exception("Could not load the task's invokable class.", t);
}
try {
return invokableClass.newInstance();
}
catch (Throwable t) {
throw new Exception("Could not instantiate the task's invokable class.", t);
}
}
/**
* Unregisters the task from the central memory manager.
*/
public void unregisterMemoryManager(MemoryManager memoryManager) {
RuntimeEnvironment env = this.environment;
if (memoryManager != null && env != null) {
memoryManager.releaseAll(env.getInvokable());
private void removeCachedFiles(Map<String, Future<Path>> entries, FileCache fileCache) {
// cancel and release all distributed cache files
try {
for (Map.Entry<String, Future<Path>> entry : entries.entrySet()) {
String name = entry.getKey();
try {
fileCache.deleteTmpFile(name, jobId);
}
catch (Exception e) {
// unpleasant, but we continue
LOG.error("Distributed Cache could not remove cached file registered under '"
+ name + "'.", e);
}
}
}
catch (Throwable t) {
LOG.error("Error while removing cached local files from distributed cache.");
}
}
protected void unregisterTask() {
taskManager.tell(new UnregisterTask(executionId), ActorRef.noSender());
private void notifyFinalState() {
taskManager.tell(new TaskInFinalState(executionId), ActorRef.noSender());
}
// -----------------------------------------------------------------------------------------------------------------
// Task Profiling
// -----------------------------------------------------------------------------------------------------------------
private void notifyFatalError(String message, Throwable cause) {
taskManager.tell(new FatalError(message, cause), ActorRef.noSender());
}
/**
* Registers the task manager profiler with the task.
*/
public void registerProfiler(TaskManagerProfiler taskManagerProfiler, Configuration jobConfiguration) {
taskManagerProfiler.registerTask(this, jobConfiguration);
// ----------------------------------------------------------------------------------------------------------------
// Canceling / Failing the task from the outside
// ----------------------------------------------------------------------------------------------------------------
public void cancelExecution() {
LOG.info("Attempting to cancel task " + taskNameWithSubtask);
if (cancelOrFailAndCancelInvokable(ExecutionState.CANCELING)) {
notifyObservers(ExecutionState.CANCELING, null);
}
}
/**
* Unregisters the task from the task manager profiler.
* Sets the tasks to be cancelled and reports a failure back to the master.
*/
public void unregisterProfiler(TaskManagerProfiler taskManagerProfiler) {
if (taskManagerProfiler != null) {
taskManagerProfiler.unregisterTask(this.executionId);
public void failExternally(Throwable cause) {
LOG.info("Attempting to fail task externally " + taskNameWithSubtask);
if (cancelOrFailAndCancelInvokable(ExecutionState.FAILED)) {
failureCause = cause;
notifyObservers(ExecutionState.FAILED, cause);
}
}
// ------------------------------------------------------------------------
// Intermediate result partitions
// ------------------------------------------------------------------------
public SingleInputGate[] getInputGates() {
return environment != null ? environment.getAllInputGates() : null;
}
private boolean cancelOrFailAndCancelInvokable(ExecutionState targetState) {
while (true) {
ExecutionState current = this.executionState;
public ResultPartitionWriter[] getWriters() {
return environment != null ? environment.getAllWriters() : null;
}
// if the task is already canceled (or canceling) or finished or failed,
// then we need not do anything
if (current.isTerminal() || current == ExecutionState.CANCELING) {
return false;
}
public ResultPartition[] getProducedPartitions() {
return environment != null ? environment.getProducedPartitions() : null;
if (current == ExecutionState.DEPLOYING || current == ExecutionState.CREATED) {
if (STATE_UPDATER.compareAndSet(this, current, targetState)) {
// if we manage this state transition, then the invokable gets never called
// we need not call cancel on it
return true;
}
}
else if (current == ExecutionState.RUNNING) {
if (STATE_UPDATER.compareAndSet(this, ExecutionState.RUNNING, targetState)) {
// we are canceling / failing out of the running state
// we need to cancel the invokable
if (invokable != null && invokableHasBeenCanceled.compareAndSet(false, true)) {
LOG.info("Triggering cancellation of task code {} ({}).", taskNameWithSubtask, executionId);
// because the canceling may block on user code, we cancel from a separate thread
Runnable canceler = new TaskCanceler(LOG, invokable, executingThread, taskNameWithSubtask);
Thread cancelThread = new Thread(executingThread.getThreadGroup(), canceler,
"Canceler for " + taskNameWithSubtask);
cancelThread.start();
}
return true;
}
}
else {
throw new IllegalStateException("Unexpected task state: " + current);
}
}
}
// --------------------------------------------------------------------------------------------
// State Listeners
// --------------------------------------------------------------------------------------------
// ------------------------------------------------------------------------
// State Listeners
// ------------------------------------------------------------------------
public void registerExecutionListener(ActorRef listener) {
executionListenerActors.add(listener);
......@@ -404,25 +777,146 @@ public class Task {
executionListenerActors.remove(listener);
}
private void notifyObservers(ExecutionState newState, String message) {
if (LOG.isInfoEnabled()) {
LOG.info(getTaskNameWithSubtasks() + " switched to " + newState + (message == null ? "" : " : " + message));
private void notifyObservers(ExecutionState newState, Throwable error) {
if (error == null) {
LOG.info(taskNameWithSubtask + " switched to " + newState);
}
else {
LOG.info(taskNameWithSubtask + " switched to " + newState + " with exception.", error);
}
TaskExecutionState stateUpdate = new TaskExecutionState(jobId, executionId, newState, error);
TaskMessages.UpdateTaskExecutionState actorMessage = new
TaskMessages.UpdateTaskExecutionState(stateUpdate);
for (ActorRef listener : executionListenerActors) {
listener.tell(new ExecutionGraphMessages.ExecutionStateChanged(
jobId, vertexId, taskName, numberOfSubtasks, subtaskIndex,
executionId, newState, System.currentTimeMillis(), message),
ActorRef.noSender());
listener.tell(actorMessage, ActorRef.noSender());
}
}
// --------------------------------------------------------------------------------------------
// Utilities
// --------------------------------------------------------------------------------------------
// ------------------------------------------------------------------------
// Notifications on the invokable
// ------------------------------------------------------------------------
public void triggerCheckpointBarrier(final long checkpointID) {
AbstractInvokable invokabe = this.invokable;
if (executionState == ExecutionState.RUNNING && invokabe != null) {
if (invokabe instanceof BarrierTransceiver) {
final BarrierTransceiver barrierTransceiver = (BarrierTransceiver) invokabe;
final Logger logger = LOG;
Thread caller = new Thread("Barrier emitter") {
@Override
public void run() {
try {
barrierTransceiver.broadcastBarrierFromSource(checkpointID);
}
catch (Throwable t) {
logger.error("Error while triggering checkpoint barriers", t);
}
}
};
caller.setDaemon(true);
caller.start();
}
else {
LOG.error("Task received a checkpoint request, but is not a checkpointing task - "
+ taskNameWithSubtask);
}
}
else {
LOG.debug("Ignoring request to trigger a checkpoint barrier");
}
}
// ------------------------------------------------------------------------
// Utilities
// ------------------------------------------------------------------------
@Override
public String toString() {
return getTaskNameWithSubtasks() + " [" + executionState + ']';
}
// ------------------------------------------------------------------------
// Task Names
// ------------------------------------------------------------------------
public static String getTaskNameWithSubtask(String name, int subtask, int numSubtasks) {
return name + " (" + (subtask+1) + '/' + numSubtasks + ')';
}
public static String getTaskNameWithSubtaskAndID(String name, int subtask, int numSubtasks, ExecutionAttemptID id) {
return name + " (" + (subtask+1) + '/' + numSubtasks + ") (" + id + ')';
}
/**
* This runner calls cancel() on the invokable and periodically interrupts the
* thread until it has terminated.
*/
private static class TaskCanceler implements Runnable {
private final Logger logger;
private final AbstractInvokable invokable;
private final Thread executer;
private final String taskName;
public TaskCanceler(Logger logger, AbstractInvokable invokable, Thread executer, String taskName) {
this.logger = logger;
this.invokable = invokable;
this.executer = executer;
this.taskName = taskName;
}
@Override
public void run() {
try {
// the user-defined cancel method may throw errors.
// we need do continue despite that
try {
invokable.cancel();
}
catch (Throwable t) {
logger.error("Error while canceling the task", t);
}
// interrupt the running thread initially
executer.interrupt();
try {
executer.join(10000);
}
catch (InterruptedException e) {
// we can ignore this
}
// it is possible that the user code does not react immediately. for that
// reason, we spawn a separate thread that repeatedly interrupts the user code until
// it exits
while (executer.isAlive()) {
// build the stack trace of where the thread is stuck, for the log
StringBuilder bld = new StringBuilder();
StackTraceElement[] stack = executer.getStackTrace();
for (StackTraceElement e : stack) {
bld.append(e).append('\n');
}
logger.warn("Task '{}' did not react to cancelling signal, but is stuck in method:\n {}",
taskName, bld.toString());
executer.interrupt();
try {
executer.join(5000);
}
catch (InterruptedException e) {
// we can ignore this
}
}
}
catch (Throwable t) {
logger.error("Error in the task canceler", t);
}
}
}
}
......@@ -24,7 +24,15 @@ import org.apache.flink.runtime.instance.InstanceID
* Miscellaneous actor messages exchanged with the TaskManager.
*/
object TaskManagerMessages {
/**
* This message informs the TaskManager about a fatal error that prevents
* it from continuing.
*
* @param description The description of the problem
*/
case class FatalError(description: String, cause: Throwable)
/**
* Tells the task manager to send a heartbeat message to the job manager.
*/
......@@ -49,7 +57,7 @@ object TaskManagerMessages {
// --------------------------------------------------------------------------
// Utility messages used for notifications during TaskManager startup
// Reporting the current TaskManager stack trace
// --------------------------------------------------------------------------
/**
......
......@@ -67,12 +67,12 @@ object TaskMessages {
extends TaskMessage
/**
* Unregister the task identified by [[executionID]] from the TaskManager.
* Sent to the TaskManager by futures and callbacks.
* Notifies the TaskManager that the task has reached its final state,
* either FINISHED, CANCELED, or FAILED.
*
* @param executionID The task's execution attempt ID.
*/
case class UnregisterTask(executionID: ExecutionAttemptID)
case class TaskInFinalState(executionID: ExecutionAttemptID)
extends TaskMessage
......
......@@ -21,7 +21,7 @@ package org.apache.flink.runtime.taskmanager
import java.io.{File, IOException}
import java.net.{InetAddress, InetSocketAddress}
import java.util
import java.util.concurrent.{TimeUnit, FutureTask}
import java.util.concurrent.TimeUnit
import java.lang.reflect.Method
import java.lang.management.{GarbageCollectorMXBean, ManagementFactory, MemoryMXBean}
......@@ -36,16 +36,13 @@ import com.codahale.metrics.jvm.{MemoryUsageGaugeSet, GarbageCollectorMetricSet}
import com.fasterxml.jackson.databind.ObjectMapper
import grizzled.slf4j.Logger
import org.apache.flink.api.common.cache.DistributedCache
import org.apache.flink.configuration._
import org.apache.flink.core.fs.Path
import org.apache.flink.runtime.{ActorSynchronousLogging, ActorLogMessages}
import org.apache.flink.runtime.akka.AkkaUtils
import org.apache.flink.runtime.blob.{BlobService, BlobCache}
import org.apache.flink.runtime.broadcast.BroadcastVariableManager
import org.apache.flink.runtime.deployment.{InputChannelDeploymentDescriptor, TaskDeploymentDescriptor}
import org.apache.flink.runtime.execution.librarycache.{BlobLibraryCacheManager, FallbackLibraryCacheManager, LibraryCacheManager}
import org.apache.flink.runtime.execution.{CancelTaskException, ExecutionState, RuntimeEnvironment}
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID
import org.apache.flink.runtime.filecache.FileCache
import org.apache.flink.runtime.instance.{HardwareDescription, InstanceConnectionInfo, InstanceID}
......@@ -54,7 +51,6 @@ import org.apache.flink.runtime.io.disk.iomanager.{IOManager, IOManagerAsync}
import org.apache.flink.runtime.io.network.NetworkEnvironment
import org.apache.flink.runtime.io.network.netty.NettyConfig
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID
import org.apache.flink.runtime.jobgraph.tasks.{OperatorStateCarrier,BarrierTransceiver}
import org.apache.flink.runtime.jobmanager.JobManager
import org.apache.flink.runtime.memorymanager.{MemoryManager, DefaultMemoryManager}
import org.apache.flink.runtime.messages.CheckpointingMessages.{CheckpointingMessage, BarrierReq}
......@@ -67,9 +63,6 @@ import org.apache.flink.runtime.process.ProcessReaper
import org.apache.flink.runtime.security.SecurityUtils
import org.apache.flink.runtime.security.SecurityUtils.FlinkSecuredRunner
import org.apache.flink.runtime.util.{MathUtils, EnvironmentInformation}
import org.apache.flink.util.ExceptionUtils
import org.slf4j.LoggerFactory
import scala.concurrent._
import scala.concurrent.duration._
......@@ -136,13 +129,13 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
protected val resources = HardwareDescription.extractFromSystem(memoryManager.getMemorySize)
/** Registry of all tasks currently executed by this TaskManager */
protected val runningTasks = new util.concurrent.ConcurrentHashMap[ExecutionAttemptID, Task]()
protected val runningTasks = new java.util.HashMap[ExecutionAttemptID, Task]()
/** Handler for shared broadcast variables (shared between multiple Tasks) */
protected val bcVarManager = new BroadcastVariableManager()
/** Handler for distributed files cached by this TaskManager */
protected val fileCache = new FileCache()
protected val fileCache = new FileCache(config.configuration)
/** Registry of metrics periodically transmitted to the JobManager */
private val metricRegistry = TaskManager.createMetricsRegistry()
......@@ -282,6 +275,9 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
case Disconnect(msg) =>
handleJobManagerDisconnect(sender(), "JobManager requested disconnect: " + msg)
case FatalError(message, cause) =>
killTaskManagerFatal(message, cause)
}
/**
......@@ -344,12 +340,13 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
// state transition
case updateMsg @ UpdateTaskExecutionState(taskExecutionState: TaskExecutionState) =>
// we receive these from our tasks and forward them to the JobManager
currentJobManager foreach {
jobManager => {
val futureResponse = (jobManager ? updateMsg)(askTimeout)
val executionID = taskExecutionState.getID
val executionState = taskExecutionState.getExecutionState
futureResponse.mapTo[Boolean].onComplete {
// IMPORTANT: In the future callback, we cannot directly modify state
......@@ -359,21 +356,16 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
self ! FailTask(executionID,
new Exception("Task has been cancelled on the JobManager."))
}
if (!result || executionState.isTerminal) {
self ! UnregisterTask(executionID)
}
case Failure(t) =>
self ! FailTask(executionID, new Exception(
"Failed to send ExecutionStateChange notification to JobManager"))
self ! UnregisterTask(executionID)
}(context.dispatcher)
}
}
// removes the task from the TaskManager and frees all its resources
case UnregisterTask(executionID) =>
case TaskInFinalState(executionID) =>
unregisterTaskAndNotifyFinalState(executionID)
// starts a new task on the TaskManager
......@@ -383,35 +375,22 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
// marks a task as failed for an external reason
// external reasons are reasons other than the task code itself throwing an exception
case FailTask(executionID, cause) =>
Option(runningTasks.get(executionID)) match {
case Some(task) =>
// execute failing operation concurrently
implicit val executor = context.dispatcher
Future {
task.failExternally(cause)
}.onFailure{
case t: Throwable => log.error(s"Could not fail task ${task} externally.", t)
}
case None =>
val task = runningTasks.get(executionID)
if (task != null) {
task.failExternally(cause)
} else {
log.debug(s"Cannot find task to fail for execution ${executionID})")
}
// cancels a task
case CancelTask(executionID) =>
Option(runningTasks.get(executionID)) match {
case Some(task) =>
// execute cancel operation concurrently
implicit val executor = context.dispatcher
Future {
task.cancelExecution()
}.onFailure{
case t: Throwable => log.error("Could not cancel task " + task, t)
}
sender ! new TaskOperationResult(executionID, true)
case None =>
sender ! new TaskOperationResult(executionID, false,
val task = runningTasks.get(executionID)
if (task != null) {
task.cancelExecution()
sender ! new TaskOperationResult(executionID, true)
} else {
log.debug(s"Cannot find task to cancel for execution ${executionID})")
sender ! new TaskOperationResult(executionID, false,
"No task with that execution ID was found.")
}
}
......@@ -430,24 +409,11 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
log.debug(s"[FT-TaskManager] Barrier $checkpointID request received " +
s"for attempt $attemptID.")
Option(runningTasks.get(attemptID)) match {
case Some(i) =>
if (i.getExecutionState == ExecutionState.RUNNING) {
i.getEnvironment.getInvokable match {
case barrierTransceiver: BarrierTransceiver =>
new Thread(new Runnable {
override def run(): Unit =
barrierTransceiver.broadcastBarrierFromSource(checkpointID)
}).start()
case _ => log.error("Taskmanager received a checkpoint request for " +
s"non-checkpointing task $attemptID.")
}
}
case None =>
// may always happen in case of canceled/finished tasks
log.debug(s"Taskmanager received a checkpoint request for unknown task $attemptID.")
val task = runningTasks.get(attemptID)
if (task != null) {
task.triggerCheckpointBarrier(checkpointID)
} else {
log.debug(s"Taskmanager received a checkpoint request for unknown task $attemptID.")
}
// unknown checkpoint message
......@@ -770,8 +736,7 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
}
}
}
// --------------------------------------------------------------------------
// Task Operations
// --------------------------------------------------------------------------
......@@ -784,130 +749,46 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
* @param tdd TaskDeploymentDescriptor describing the task to be executed on this [[TaskManager]]
*/
private def submitTask(tdd: TaskDeploymentDescriptor): Unit = {
val slot = tdd.getTargetSlotNumber
if (!isConnected) {
sender ! Failure(
new IllegalStateException("TaskManager is not associated with a JobManager.")
)
} else if (slot < 0 || slot >= numberOfSlots) {
sender ! Failure(new Exception(s"Target slot $slot does not exist on TaskManager."))
} else {
sender ! Acknowledge
Future {
initializeTask(tdd)
}(context.dispatcher)
}
}
/** Sets up a [[org.apache.flink.runtime.execution.RuntimeEnvironment]] for the task and starts
* its execution in a separate thread.
*
* @param tdd TaskDeploymentDescriptor describing the task to be executed on this [[TaskManager]]
*/
private def initializeTask(tdd: TaskDeploymentDescriptor): Unit ={
val jobID = tdd.getJobID
val vertexID = tdd.getVertexID
val executionID = tdd.getExecutionId
val taskIndex = tdd.getIndexInSubtaskGroup
val numSubtasks = tdd.getNumberOfSubtasks
var startRegisteringTask = 0L
var task: Task = null
try {
val userCodeClassLoader = libraryCacheManager match {
case Some(manager) =>
if (log.isDebugEnabled) {
startRegisteringTask = System.currentTimeMillis()
}
// triggers the download of all missing jar files from the job manager
manager.registerTask(jobID, executionID, tdd.getRequiredJarFiles)
if (log.isDebugEnabled) {
log.debug(s"Register task $executionID at library cache manager " +
s"took ${(System.currentTimeMillis() - startRegisteringTask) / 1000.0}s")
}
manager.getClassLoader(jobID)
case None => throw new IllegalStateException("There is no valid library cache manager.")
}
if (userCodeClassLoader == null) {
throw new RuntimeException("No user code Classloader available.")
}
task = new Task(jobID, vertexID, taskIndex, numSubtasks, executionID,
tdd.getTaskName, self)
Option(runningTasks.put(executionID, task)) match {
case Some(_) => throw new RuntimeException(
s"TaskManager contains already a task with executionID $executionID.")
// grab some handles and sanity check on the fly
val jobManagerActor = currentJobManager match {
case Some(jm) => jm
case None =>
throw new IllegalStateException("TaskManager is not associated with a JobManager.")
}
val env = currentJobManager match {
case Some(jobManager) =>
val splitProvider = new TaskInputSplitProvider(jobManager, jobID, vertexID,
executionID, userCodeClassLoader, askTimeout)
new RuntimeEnvironment(jobManager, task, tdd, userCodeClassLoader,
memoryManager, ioManager, splitProvider, bcVarManager, network)
case None => throw new IllegalStateException(
"TaskManager has not yet been registered at a JobManager.")
}
task.setEnvironment(env)
//inject operator state
if (tdd.getOperatorStates != null) {
task.getEnvironment.getInvokable match {
case opStateCarrier: OperatorStateCarrier =>
opStateCarrier.injectState(tdd.getOperatorStates)
}
val libCache = libraryCacheManager match {
case Some(manager) => manager
case None => throw new IllegalStateException("There is no valid library cache manager.")
}
// register the task with the network stack and profiles
log.info(s"Register task $task.")
network.registerTask(task)
val cpTasks = new util.HashMap[String, FutureTask[Path]]()
for (entry <- DistributedCache.readFileInfoFromConfig(tdd.getJobConfiguration).asScala) {
val cp = fileCache.createTmpFile(entry.getKey, entry.getValue, jobID)
cpTasks.put(entry.getKey, cp)
val slot = tdd.getTargetSlotNumber
if (slot < 0 || slot >= numberOfSlots) {
throw new IllegalArgumentException(s"Target slot $slot does not exist on TaskManager.")
}
env.addCopyTasksForCacheFile(cpTasks)
if (!task.startExecution()) {
throw new RuntimeException("Cannot start task. Task was canceled or failed.")
// create the task. this does not grab any TaskManager resources or download
// and libraries - the operation does not block
val execId = tdd.getExecutionId
val task = new Task(tdd, memoryManager, ioManager, network, bcVarManager,
self, jobManagerActor, config.timeout, libCache, fileCache)
// add the task to the map
val prevTask = runningTasks.put(execId, task)
if (prevTask != null) {
// already have a task for that ID, put if back and report an error
runningTasks.put(execId, prevTask)
throw new IllegalStateException("TaskManager already contains a task for id " + execId)
}
self ! UpdateTaskExecutionState(
new TaskExecutionState(jobID, executionID, ExecutionState.RUNNING)
)
} catch {
case t: Throwable =>
if (!t.isInstanceOf[CancelTaskException]) {
log.error("Could not instantiate task with execution ID " + executionID, t)
}
try {
if (task != null) {
task.failExternally(t)
removeAllTaskResources(task)
}
libraryCacheManager foreach { _.unregisterTask(jobID, executionID) }
} catch {
case t: Throwable => log.error("Error during cleanup of task deployment.", t)
}
self ! UpdateTaskExecutionState(
new TaskExecutionState(jobID, executionID, ExecutionState.FAILED, t)
)
// all good, we kick off the task, which performs its own initialization
task.startTaskThread()
sender ! Acknowledge
}
catch {
case t: Throwable =>
log.error("SubmitTask failed", t)
sender ! Failure(t)
}
}
......@@ -927,19 +808,20 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
val errors: Seq[String] = partitionInfos.flatMap { info =>
val (resultID, partitionInfo) = info
val reader = task.getEnvironment.getInputGateById(resultID)
val reader = task.getInputGateById(resultID)
if (reader != null) {
Future {
try {
reader.updateInputChannel(partitionInfo)
} catch {
}
catch {
case t: Throwable =>
log.error(s"Could not update input data location for task " +
s"${task.getTaskName}. Trying to fail task.", t)
try {
task.markFailed(t)
task.failExternally(t)
}
catch {
case t: Throwable =>
......@@ -977,20 +859,20 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
private def cancelAndClearEverything(cause: Throwable) {
if (runningTasks.size > 0) {
log.info("Cancelling all computations and discarding all cached data.")
for (t <- runningTasks.values().asScala) {
t.failExternally(cause)
unregisterTaskAndNotifyFinalState(t.getExecutionId)
}
runningTasks.clear()
}
}
private def unregisterTaskAndNotifyFinalState(executionID: ExecutionAttemptID): Unit = {
Option(runningTasks.remove(executionID)) match {
case Some(task) =>
// mark the task as failed if it is not yet in a final state
val task = runningTasks.remove(executionID)
if (task != null) {
// the task must be in a terminal state
if (!task.getExecutionState.isTerminal) {
try {
task.failExternally(new Exception("Task is being removed from TaskManager"))
......@@ -999,66 +881,15 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
}
}
log.info(s"Unregister task with execution ID $executionID.")
removeAllTaskResources(task)
libraryCacheManager foreach { _.unregisterTask(task.getJobID, executionID) }
log.info(s"Updating FINAL execution state of ${task.getTaskName} " +
s"(${task.getExecutionId}) to ${task.getExecutionState}.")
log.info(s"Unregistering task and sending final execution state " +
s"${task.getExecutionState} to JobManager for task ${task.getTaskName} " +
s"(${task.getExecutionId})")
self ! UpdateTaskExecutionState(new TaskExecutionState(
task.getJobID, task.getExecutionId, task.getExecutionState, task.getFailureCause))
case None =>
log.debug(s"Cannot find task with ID $executionID to unregister.")
}
}
/**
* This method cleans up the resources of a task in the distributed cache,
* network stack and the memory manager.
*
* If the cleanup in the network stack or memory manager fails, this is considered
* a fatal problem (critical resource leak) and causes the TaskManager to quit.
* A TaskManager JVM restart is the best safe way to fix that error.
*
* @param task The Task whose resources should be cleared.
*/
private def removeAllTaskResources(task: Task): Unit = {
// release the critical things first, and fail fatally if it does not work
// this releases all task resources, like buffer pools and intermediate result
// partitions being built. If this fails, the TaskManager is in serious trouble,
// as this is a massive resource leak. We kill the TaskManager in that case,
// to recover through a clean JVM start
try {
network.unregisterTask(task)
} catch {
case t: Throwable =>
killTaskManagerFatal("Failed to unregister task resources from network stack", t)
}
// safety net to release all the task's memory
try {
task.unregisterMemoryManager(memoryManager)
} catch {
case t: Throwable =>
killTaskManagerFatal("Failed to unregister task memory from memory manager", t)
}
// release temp files from the distributed cache
if (task.getEnvironment != null) {
try {
for (entry <- DistributedCache.readFileInfoFromConfig(
task.getEnvironment.getJobConfiguration).asScala) {
fileCache.deleteTmpFile(entry.getKey, entry.getValue, task.getJobID)
}
} catch {
// this is pretty unpleasant, but not a reason to give up immediately
case e: Exception => log.error(
"Error cleaning up local temp files from the distributed cache.", e)
}
else {
log.error(s"Cannot find task with ID $executionID to unregister.")
}
}
......
......@@ -19,8 +19,8 @@
package org.apache.flink.runtime.io.network.partition.consumer;
import com.google.common.collect.Lists;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
......@@ -36,6 +36,7 @@ import org.apache.flink.runtime.io.network.util.TestPartitionProducer;
import org.apache.flink.runtime.io.network.util.TestProducerSource;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.junit.Test;
import java.util.Collections;
......@@ -93,7 +94,7 @@ public class LocalInputChannelTest {
partitionIds[i] = new ResultPartitionID();
final ResultPartition partition = new ResultPartition(
mock(Environment.class),
"Test Name",
jobId,
partitionIds[i],
ResultPartitionType.PIPELINED,
......@@ -222,7 +223,7 @@ public class LocalInputChannelTest {
checkArgument(numberOfExpectedBuffersPerChannel >= 1);
this.inputGate = new SingleInputGate(
mock(Environment.class),
"Test Name",
new IntermediateDataSetID(),
subpartitionIndex,
numberOfInputChannels);
......
......@@ -22,8 +22,6 @@ import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionLocation;
import org.apache.flink.runtime.event.task.TaskEvent;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.execution.RuntimeEnvironment;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
......@@ -61,7 +59,7 @@ public class SingleInputGateTest {
public void testBasicGetNextLogic() throws Exception {
// Setup
final SingleInputGate inputGate = new SingleInputGate(
mock(Environment.class), new IntermediateDataSetID(), 0, 2);
"Test Task Name", new IntermediateDataSetID(), 0, 2);
final TestInputChannel[] inputChannels = new TestInputChannel[]{
new TestInputChannel(inputGate, 0),
......@@ -107,7 +105,7 @@ public class SingleInputGateTest {
// Setup reader with one local and one unknown input channel
final IntermediateDataSetID resultId = new IntermediateDataSetID();
final SingleInputGate inputGate = new SingleInputGate(mock(Environment.class), resultId, 0, 2);
final SingleInputGate inputGate = new SingleInputGate("Test Task Name", resultId, 0, 2);
final BufferPool bufferPool = mock(BufferPool.class);
when(bufferPool.getNumberOfRequiredMemorySegments()).thenReturn(2);
......
......@@ -18,7 +18,6 @@
package org.apache.flink.runtime.io.network.partition.consumer;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
......@@ -30,7 +29,6 @@ import java.util.List;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkElementIndex;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
/**
......@@ -50,7 +48,7 @@ public class TestSingleInputGate {
checkArgument(numberOfInputChannels >= 1);
this.inputGate = spy(new SingleInputGate(
mock(Environment.class), new IntermediateDataSetID(), 0, numberOfInputChannels));
"Test Task Name", new IntermediateDataSetID(), 0, numberOfInputChannels));
this.inputChannels = new TestInputChannel[numberOfInputChannels];
......
......@@ -18,14 +18,13 @@
package org.apache.flink.runtime.io.network.partition.consumer;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
public class UnionInputGateTest {
......@@ -39,9 +38,9 @@ public class UnionInputGateTest {
@Test(timeout = 120 * 1000)
public void testBasicGetNextLogic() throws Exception {
// Setup
final Environment env = mock(Environment.class);
final SingleInputGate ig1 = new SingleInputGate(env, new IntermediateDataSetID(), 0, 3);
final SingleInputGate ig2 = new SingleInputGate(env, new IntermediateDataSetID(), 0, 5);
final String testTaskName = "Test Task";
final SingleInputGate ig1 = new SingleInputGate(testTaskName, new IntermediateDataSetID(), 0, 3);
final SingleInputGate ig2 = new SingleInputGate(testTaskName, new IntermediateDataSetID(), 0, 5);
final UnionInputGate union = new UnionInputGate(new SingleInputGate[]{ig1, ig2});
......
......@@ -18,6 +18,7 @@
package org.apache.flink.runtime.operators.testutils;
import akka.actor.ActorRef;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
......@@ -261,4 +262,9 @@ public class MockEnvironment implements Environment {
public void reportAccumulators(Map<String, Accumulator<?, ?>> accumulators) {
// discard, this is only for testing
}
@Override
public ActorRef getJobManager() {
return ActorRef.noSender();
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.taskmanager;
import akka.actor.UntypedActor;
import java.util.concurrent.BlockingQueue;
/**
* Actor for testing that simply puts all its messages into a
* blocking queue.
*/
class ForwardingActor extends UntypedActor {
private final BlockingQueue<Object> queue;
public ForwardingActor(BlockingQueue<Object> queue) {
this.queue = queue;
}
@Override
public void onReceive(Object message) {
queue.add(message);
}
}
......@@ -21,6 +21,8 @@ package org.apache.flink.runtime.taskmanager;
import static org.junit.Assert.*;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
......@@ -84,4 +86,35 @@ public class TaskExecutionStateTest {
fail(e.getMessage());
}
}
@Test
public void hanleNonSerializableException() {
try {
@SuppressWarnings({"ThrowableInstanceNeverThrown", "serial"})
Exception hostile = new Exception() {
// should be non serializable, because it contains the outer class reference
@Override
public String getMessage() {
throw new RuntimeException("Cannot get Message");
}
@Override
public void printStackTrace(PrintStream s) {
throw new RuntimeException("Cannot print");
}
@Override
public void printStackTrace(PrintWriter s) {
throw new RuntimeException("Cannot print");
}
};
new TaskExecutionState(new JobID(), new ExecutionAttemptID(), ExecutionState.FAILED, hostile);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
}
......@@ -29,6 +29,7 @@ import akka.testkit.JavaTestKit;
import akka.util.Timeout;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.akka.AkkaUtils;
import org.apache.flink.runtime.blob.BlobKey;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
......@@ -56,11 +57,14 @@ import org.apache.flink.runtime.messages.TaskMessages.SubmitTask;
import org.apache.flink.runtime.messages.TaskMessages.TaskOperationResult;
import org.apache.flink.runtime.testingUtils.TestingTaskManager;
import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
import org.apache.flink.runtime.testingUtils.TestingUtils;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.concurrent.Await;
import scala.concurrent.Future;
......@@ -75,42 +79,61 @@ import java.util.Set;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
@SuppressWarnings("serial")
public class TaskManagerTest {
private static ActorSystem system;
private static Timeout timeout = new Timeout(1, TimeUnit.MINUTES);
private static final Logger LOG = LoggerFactory.getLogger(TaskManagerTest.class);
private static final Timeout timeout = new Timeout(1, TimeUnit.MINUTES);
private final FiniteDuration d = new FiniteDuration(20, TimeUnit.SECONDS);
private static final FiniteDuration d = new FiniteDuration(20, TimeUnit.SECONDS);
private static ActorSystem system;
@BeforeClass
public static void setup() {
system = ActorSystem.create("TestActorSystem", TestingUtils.testConfig());
system = AkkaUtils.createLocalActorSystem(new Configuration());
}
@AfterClass
public static void teardown() {
JavaTestKit.shutdownActorSystem(system);
system = null;
}
@Test
public void testSetupTaskManager() {
public void testSubmitAndExecuteTask() {
LOG.info( "--------------------------------------------------------------------\n" +
" Starting testSubmitAndExecuteTask() \n" +
"--------------------------------------------------------------------");
new JavaTestKit(system){{
ActorRef jobManager = null;
ActorRef taskManager = null;
try {
jobManager = system.actorOf(Props.create(SimpleJobManager.class));
taskManager = createTaskManager(jobManager);
taskManager = createTaskManager(getTestActor(), false);
final ActorRef tmClosure = taskManager;
// handle the registration
new Within(d) {
@Override
protected void run() {
expectMsgClass(RegistrationMessages.RegisterTaskManager.class);
final InstanceID iid = new InstanceID();
assertEquals(tmClosure, getLastSender());
tmClosure.tell(new RegistrationMessages.AcknowledgeRegistration(
getTestActor(), iid, 12345), getTestActor());
}
};
JobID jid = new JobID();
JobVertexID vid = new JobVertexID();
final JobID jid = new JobID();
final JobVertexID vid = new JobVertexID();
final ExecutionAttemptID eid = new ExecutionAttemptID();
final TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(jid, vid, eid, "TestTask", 2, 7,
......@@ -119,13 +142,54 @@ public class TaskManagerTest {
Collections.<InputGateDeploymentDescriptor>emptyList(),
new ArrayList<BlobKey>(), 0);
final ActorRef tmClosure = taskManager;
new Within(d) {
@Override
protected void run() {
tmClosure.tell(new SubmitTask(tdd), getRef());
expectMsgEquals(Messages.getAcknowledge());
// TaskManager should acknowledge the submission
// heartbeats may be interleaved
long deadline = System.currentTimeMillis() + 10000;
do {
Object message = receiveOne(d);
if (message == Messages.getAcknowledge()) {
break;
}
} while (System.currentTimeMillis() < deadline);
// task should have switched to running
Object toRunning = new TaskMessages.UpdateTaskExecutionState(
new TaskExecutionState(jid, eid, ExecutionState.RUNNING));
// task should have switched to finished
Object toFinished = new TaskMessages.UpdateTaskExecutionState(
new TaskExecutionState(jid, eid, ExecutionState.FINISHED));
deadline = System.currentTimeMillis() + 10000;
do {
Object message = receiveOne(d);
if (message.equals(toRunning)) {
break;
}
else if (!(message instanceof TaskManagerMessages.Heartbeat)) {
fail("Unexpected message: " + message);
}
} while (System.currentTimeMillis() < deadline);
deadline = System.currentTimeMillis() + 10000;
do {
Object message = receiveOne(d);
if (message.equals(toFinished)) {
break;
}
else if (!(message instanceof TaskManagerMessages.Heartbeat)) {
fail("Unexpected message: " + message);
}
} while (System.currentTimeMillis() < deadline);
}
};
}
......@@ -138,22 +202,24 @@ public class TaskManagerTest {
if (taskManager != null) {
taskManager.tell(Kill.getInstance(), ActorRef.noSender());
}
if (jobManager != null) {
jobManager.tell(Kill.getInstance(), ActorRef.noSender());
}
}
}};
}
@Test
public void testJobSubmissionAndCanceling() {
LOG.info( "--------------------------------------------------------------------\n" +
" Starting testJobSubmissionAndCanceling() \n" +
"--------------------------------------------------------------------");
new JavaTestKit(system){{
ActorRef jobManager = null;
ActorRef taskManager = null;
try {
jobManager = system.actorOf(Props.create(SimpleJobManager.class));
taskManager = createTaskManager(jobManager);
taskManager = createTaskManager(jobManager, true);
final JobID jid1 = new JobID();
final JobID jid2 = new JobID();
......@@ -274,6 +340,11 @@ public class TaskManagerTest {
@Test
public void testGateChannelEdgeMismatch() {
LOG.info( "--------------------------------------------------------------------\n" +
" Starting testGateChannelEdgeMismatch() \n" +
"--------------------------------------------------------------------");
new JavaTestKit(system){{
ActorRef jobManager = null;
......@@ -281,7 +352,7 @@ public class TaskManagerTest {
try {
jobManager = system.actorOf(Props.create(SimpleJobManager.class));
taskManager = createTaskManager(jobManager);
taskManager = createTaskManager(jobManager, true);
final ActorRef tm = taskManager;
final JobID jid = new JobID();
......@@ -353,6 +424,11 @@ public class TaskManagerTest {
@Test
public void testRunJobWithForwardChannel() {
LOG.info( "--------------------------------------------------------------------\n" +
" Starting testRunJobWithForwardChannel() \n" +
"--------------------------------------------------------------------");
new JavaTestKit(system){{
ActorRef jobManager = null;
......@@ -368,7 +444,7 @@ public class TaskManagerTest {
jobManager = system.actorOf(Props.create(new SimpleLookupJobManagerCreator()));
taskManager = createTaskManager(jobManager);
taskManager = createTaskManager(jobManager, true);
final ActorRef tm = taskManager;
IntermediateResultPartitionID partitionId = new IntermediateResultPartitionID();
......@@ -470,6 +546,10 @@ public class TaskManagerTest {
@Test
public void testCancellingDependentAndStateUpdateFails() {
LOG.info( "--------------------------------------------------------------------\n" +
" Starting testCancellingDependentAndStateUpdateFails() \n" +
"--------------------------------------------------------------------");
// this tests creates two tasks. the sender sends data, and fails to send the
// state update back to the job manager
// the second one blocks to be canceled
......@@ -491,7 +571,7 @@ public class TaskManagerTest {
new SimpleLookupFailingUpdateJobManagerCreator(eid2)
)
);
taskManager = createTaskManager(jobManager);
taskManager = createTaskManager(jobManager, true);
final ActorRef tm = taskManager;
IntermediateResultPartitionID partitionId = new IntermediateResultPartitionID();
......@@ -676,7 +756,7 @@ public class TaskManagerTest {
}
}
public static ActorRef createTaskManager(ActorRef jobManager) {
public static ActorRef createTaskManager(ActorRef jobManager, boolean waitForRegistration) {
ActorRef taskManager = null;
try {
Configuration cfg = new Configuration();
......@@ -695,16 +775,18 @@ public class TaskManagerTest {
fail("Could not create test TaskManager: " + e.getMessage());
}
Future<Object> response = Patterns.ask(taskManager,
TaskManagerMessages.getNotifyWhenRegisteredAtJobManagerMessage(), timeout);
try {
FiniteDuration d = new FiniteDuration(100, TimeUnit.SECONDS);
Await.ready(response, d);
}
catch (Exception e) {
e.printStackTrace();
fail("Exception while waiting for the task manager registration: " + e.getMessage());
if (waitForRegistration) {
Future<Object> response = Patterns.ask(taskManager,
TaskManagerMessages.getNotifyWhenRegisteredAtJobManagerMessage(), timeout);
try {
FiniteDuration d = new FiniteDuration(100, TimeUnit.SECONDS);
Await.ready(response, d);
}
catch (Exception e) {
e.printStackTrace();
fail("Exception while waiting for the task manager registration: " + e.getMessage());
}
}
return taskManager;
......
......@@ -19,72 +19,277 @@
package org.apache.flink.runtime.taskmanager;
import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.actor.Kill;
import akka.actor.Props;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.akka.AkkaUtils;
import org.apache.flink.runtime.blob.BlobKey;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.execution.RuntimeEnvironment;
import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.filecache.FileCache;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.MockNetworkEnvironment;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.io.network.NetworkEnvironment;
import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
import org.apache.flink.runtime.memorymanager.MemoryManager;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.runtime.messages.TaskMessages;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import scala.concurrent.duration.FiniteDuration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Tests for the Task, which make sure that correct state transitions happen,
* and failures are correctly handled.
*
* All tests here have a set of mock actors for TaskManager, JobManager, and
* execution listener, which simply put the messages in a queue to be picked
* up by the test and validated.
*/
public class TaskTest {
private static ActorSystem actorSystem;
private static OneShotLatch awaitLatch;
private static OneShotLatch triggerLatch;
private ActorRef taskManagerMock;
private ActorRef jobManagerMock;
private ActorRef listenerActor;
private BlockingQueue<Object> taskManagerMessages;
private BlockingQueue<Object> jobManagerMessages;
private BlockingQueue<Object> listenerMessages;
// ------------------------------------------------------------------------
// Init & Shutdown
// ------------------------------------------------------------------------
@BeforeClass
public static void startActorSystem() {
actorSystem = AkkaUtils.createLocalActorSystem(new Configuration());
}
@AfterClass
public static void shutdown() {
actorSystem.shutdown();
actorSystem.awaitTermination();
}
@Before
public void createQueuesAndActors() {
taskManagerMessages = new LinkedBlockingQueue<Object>();
jobManagerMessages = new LinkedBlockingQueue<Object>();
listenerMessages = new LinkedBlockingQueue<Object>();
taskManagerMock = actorSystem.actorOf(Props.create(ForwardingActor.class, taskManagerMessages));
jobManagerMock = actorSystem.actorOf(Props.create(ForwardingActor.class, jobManagerMessages));
listenerActor = actorSystem.actorOf(Props.create(ForwardingActor.class, listenerMessages));
awaitLatch = new OneShotLatch();
triggerLatch = new OneShotLatch();
}
@After
public void clearActorsAndMessages() {
jobManagerMessages = null;
taskManagerMessages = null;
listenerMessages = null;
taskManagerMock.tell(Kill.getInstance(), ActorRef.noSender());
jobManagerMock.tell(Kill.getInstance(), ActorRef.noSender());
listenerActor.tell(Kill.getInstance(), ActorRef.noSender());
}
// ------------------------------------------------------------------------
// Tests
// ------------------------------------------------------------------------
@Test
public void testTaskStates() {
public void testRegularExecution() {
try {
final JobID jid = new JobID();
final JobVertexID vid = new JobVertexID();
final ExecutionAttemptID eid = new ExecutionAttemptID();
Task task = createTask(TestInvokableCorrect.class);
final RuntimeEnvironment env = mock(RuntimeEnvironment.class);
// task should be new and perfect
assertEquals(ExecutionState.CREATED, task.getExecutionState());
assertFalse(task.isCanceledOrFailed());
assertNull(task.getFailureCause());
Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender()));
doNothing().when(task).unregisterTask();
task.setEnvironment(env);
task.registerExecutionListener(listenerActor);
assertEquals(ExecutionState.DEPLOYING, task.getExecutionState());
// go into the run method. we should switch to DEPLOYING, RUNNING, then
// FINISHED, and all should be good
task.run();
// cancel
task.cancelExecution();
assertEquals(ExecutionState.CANCELED, task.getExecutionState());
// verify final state
assertEquals(ExecutionState.FINISHED, task.getExecutionState());
assertFalse(task.isCanceledOrFailed());
assertNull(task.getFailureCause());
// cannot go into running or finished state
// verify listener messages
validateListenerMessage(ExecutionState.RUNNING, task, false);
validateListenerMessage(ExecutionState.FINISHED, task, false);
assertFalse(task.startExecution());
assertEquals(ExecutionState.CANCELED, task.getExecutionState());
// make sure that the TaskManager received an message to unregister the task
validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
validateUnregisterTask(task.getExecutionId());
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testCancelRightAway() {
try {
Task task = createTask(TestInvokableCorrect.class);
task.cancelExecution();
assertEquals(ExecutionState.CANCELING, task.getExecutionState());
assertFalse(task.markAsFinished());
task.run();
// verify final state
assertEquals(ExecutionState.CANCELED, task.getExecutionState());
validateUnregisterTask(task.getExecutionId());
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testFailExternallyRightAway() {
try {
Task task = createTask(TestInvokableCorrect.class);
task.failExternally(new Exception("fail externally"));
assertEquals(ExecutionState.FAILED, task.getExecutionState());
task.run();
// verify final state
assertEquals(ExecutionState.FAILED, task.getExecutionState());
validateUnregisterTask(task.getExecutionId());
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testLibraryCacheRegistrationFailed() {
try {
Task task = createTask(TestInvokableCorrect.class, mock(LibraryCacheManager.class));
// task should be new and perfect
assertEquals(ExecutionState.CREATED, task.getExecutionState());
assertFalse(task.isCanceledOrFailed());
assertNull(task.getFailureCause());
task.registerExecutionListener(listenerActor);
// should fail
task.run();
// verify final state
assertEquals(ExecutionState.FAILED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertNotNull(task.getFailureCause());
assertTrue(task.getFailureCause().getMessage().contains("classloader"));
// verify listener messages
validateListenerMessage(ExecutionState.FAILED, task, true);
// make sure that the TaskManager received an message to unregister the task
validateUnregisterTask(task.getExecutionId());
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testExecutionFailsInNetworkRegistration() {
try {
// mock a working library cache
LibraryCacheManager libCache = mock(LibraryCacheManager.class);
when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());
// mock a network manager that rejects registration
ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class);
NetworkEnvironment network = mock(NetworkEnvironment.class);
when(network.getPartitionManager()).thenReturn(partitionManager);
when(network.getPartitionConsumableNotifier()).thenReturn(consumableNotifier);
when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC);
doThrow(new RuntimeException("buffers")).when(network).registerTask(any(Task.class));
Task task = createTask(TestInvokableCorrect.class, libCache, network);
task.registerExecutionListener(listenerActor);
task.run();
assertEquals(ExecutionState.FAILED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertTrue(task.getFailureCause().getMessage().contains("buffers"));
task.markFailed(new Exception("test"));
assertTrue(ExecutionState.CANCELED == task.getExecutionState());
validateUnregisterTask(task.getExecutionId());
validateListenerMessage(ExecutionState.FAILED, task, true);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
verify(task).unregisterTask();
@Test
public void testInvokableInstantiationFailed() {
try {
Task task = createTask(InvokableNonInstantiable.class);
task.registerExecutionListener(listenerActor);
task.run();
assertEquals(ExecutionState.FAILED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertTrue(task.getFailureCause().getMessage().contains("instantiate"));
validateUnregisterTask(task.getExecutionId());
validateListenerMessage(ExecutionState.FAILED, task, true);
}
catch (Exception e) {
e.printStackTrace();
......@@ -93,48 +298,19 @@ public class TaskTest {
}
@Test
public void testTaskStartFinish() {
public void testExecutionFailsInRegisterInputOutput() {
try {
final JobID jid = new JobID();
final JobVertexID vid = new JobVertexID();
final ExecutionAttemptID eid = new ExecutionAttemptID();
final Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender()));
doNothing().when(task).unregisterTask();
final AtomicReference<Throwable> error = new AtomicReference<Throwable>();
Thread operation = new Thread() {
@Override
public void run() {
try {
assertTrue(task.markAsFinished());
}
catch (Throwable t) {
error.set(t);
}
}
};
final RuntimeEnvironment env = mock(RuntimeEnvironment.class);
when(env.getExecutingThread()).thenReturn(operation);
assertEquals(ExecutionState.DEPLOYING, task.getExecutionState());
// start the execution
task.setEnvironment(env);
task.startExecution();
// wait for the execution to be finished
operation.join();
if (error.get() != null) {
ExceptionUtils.rethrow(error.get());
}
assertEquals(ExecutionState.FINISHED, task.getExecutionState());
Task task = createTask(InvokableWithExceptionInRegisterInOut.class);
task.registerExecutionListener(listenerActor);
verify(task).unregisterTask();
task.run();
assertEquals(ExecutionState.FAILED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertTrue(task.getFailureCause().getMessage().contains("registerInputOutput"));
validateUnregisterTask(task.getExecutionId());
validateListenerMessage(ExecutionState.FAILED, task, true);
}
catch (Exception e) {
e.printStackTrace();
......@@ -143,48 +319,22 @@ public class TaskTest {
}
@Test
public void testTaskFailesInRunning() {
public void testExecutionFailsInInvoke() {
try {
final JobID jid = new JobID();
final JobVertexID vid = new JobVertexID();
final ExecutionAttemptID eid = new ExecutionAttemptID();
final Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender()));
doNothing().when(task).unregisterTask();
final AtomicReference<Throwable> error = new AtomicReference<Throwable>();
Thread operation = new Thread() {
@Override
public void run() {
try {
task.markFailed(new Exception("test exception message"));
}
catch (Throwable t) {
error.set(t);
}
}
};
final RuntimeEnvironment env = mock(RuntimeEnvironment.class);
when(env.getExecutingThread()).thenReturn(operation);
Task task = createTask(InvokableWithExceptionInInvoke.class);
task.registerExecutionListener(listenerActor);
assertEquals(ExecutionState.DEPLOYING, task.getExecutionState());
// start the execution
task.setEnvironment(env);
task.startExecution();
// wait for the execution to be finished
operation.join();
if (error.get() != null) {
ExceptionUtils.rethrow(error.get());
}
// make sure the final state is correct and the task manager knows the changes
task.run();
assertEquals(ExecutionState.FAILED, task.getExecutionState());
verify(task).unregisterTask();
assertTrue(task.isCanceledOrFailed());
assertTrue(task.getFailureCause().getMessage().contains("test"));
validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
validateUnregisterTask(task.getExecutionId());
validateListenerMessage(ExecutionState.RUNNING, task, false);
validateListenerMessage(ExecutionState.FAILED, task, true);
}
catch (Exception e) {
e.printStackTrace();
......@@ -193,59 +343,180 @@ public class TaskTest {
}
@Test
public void testTaskCanceledInRunning() {
public void testCancelDuringRegisterInputOutput() {
try {
final JobID jid = new JobID();
final JobVertexID vid = new JobVertexID();
final ExecutionAttemptID eid = new ExecutionAttemptID();
final Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender()));
doNothing().when(task).unregisterTask();
final AtomicReference<Throwable> error = new AtomicReference<Throwable>();
// latches to create a deterministic order of events
final OneShotLatch toRunning = new OneShotLatch();
final OneShotLatch afterCanceling = new OneShotLatch();
Thread operation = new Thread() {
@Override
public void run() {
try {
toRunning.trigger();
afterCanceling.await();
assertFalse(task.markAsFinished());
task.cancelingDone();
}
catch (Throwable t) {
error.set(t);
}
}
};
Task task = createTask(InvokableBlockingInRegisterInOut.class);
task.registerExecutionListener(listenerActor);
// run the task asynchronous
task.startTaskThread();
final RuntimeEnvironment env = mock(RuntimeEnvironment.class);
when(env.getExecutingThread()).thenReturn(operation);
// wait till the task is in regInOut
awaitLatch.await();
assertEquals(ExecutionState.DEPLOYING, task.getExecutionState());
task.cancelExecution();
assertEquals(ExecutionState.CANCELING, task.getExecutionState());
triggerLatch.trigger();
// start the execution
task.setEnvironment(env);
task.startExecution();
task.getExecutingThread().join();
assertEquals(ExecutionState.CANCELED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertNull(task.getFailureCause());
toRunning.await();
validateUnregisterTask(task.getExecutionId());
validateListenerMessage(ExecutionState.CANCELING, task, false);
validateListenerMessage(ExecutionState.CANCELED, task, false);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testFailDuringRegisterInputOutput() {
try {
Task task = createTask(InvokableBlockingInRegisterInOut.class);
task.registerExecutionListener(listenerActor);
// run the task asynchronous
task.startTaskThread();
// wait till the task is in regInOut
awaitLatch.await();
task.failExternally(new Exception("test"));
assertEquals(ExecutionState.FAILED, task.getExecutionState());
triggerLatch.trigger();
task.getExecutingThread().join();
assertEquals(ExecutionState.FAILED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertTrue(task.getFailureCause().getMessage().contains("test"));
validateUnregisterTask(task.getExecutionId());
validateListenerMessage(ExecutionState.FAILED, task, true);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testCancelDuringInvoke() {
try {
Task task = createTask(InvokableBlockingInInvoke.class);
task.registerExecutionListener(listenerActor);
// run the task asynchronous
task.startTaskThread();
// wait till the task is in invoke
awaitLatch.await();
task.cancelExecution();
afterCanceling.trigger();
assertTrue(task.getExecutionState() == ExecutionState.CANCELING ||
task.getExecutionState() == ExecutionState.CANCELED);
task.getExecutingThread().join();
assertEquals(ExecutionState.CANCELED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertNull(task.getFailureCause());
validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
validateUnregisterTask(task.getExecutionId());
// wait for the execution to be finished
operation.join();
validateListenerMessage(ExecutionState.RUNNING, task, false);
validateListenerMessage(ExecutionState.CANCELING, task, false);
validateListenerMessage(ExecutionState.CANCELED, task, false);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testFailExternallyDuringInvoke() {
try {
Task task = createTask(InvokableBlockingInInvoke.class);
task.registerExecutionListener(listenerActor);
// run the task asynchronous
task.startTaskThread();
// wait till the task is in regInOut
awaitLatch.await();
task.failExternally(new Exception("test"));
assertTrue(task.getExecutionState() == ExecutionState.FAILED);
task.getExecutingThread().join();
assertEquals(ExecutionState.FAILED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertTrue(task.getFailureCause().getMessage().contains("test"));
validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
validateUnregisterTask(task.getExecutionId());
validateListenerMessage(ExecutionState.RUNNING, task, false);
validateListenerMessage(ExecutionState.FAILED, task, true);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testCanceledAfterExecutionFailedInRegInOut() {
try {
Task task = createTask(InvokableWithExceptionInRegisterInOut.class);
task.registerExecutionListener(listenerActor);
task.run();
if (error.get() != null) {
ExceptionUtils.rethrow(error.get());
}
// this should not overwrite the failure state
task.cancelExecution();
assertEquals(ExecutionState.FAILED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertTrue(task.getFailureCause().getMessage().contains("registerInputOutput"));
validateUnregisterTask(task.getExecutionId());
validateListenerMessage(ExecutionState.FAILED, task, true);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testCanceledAfterExecutionFailedInInvoke() {
try {
Task task = createTask(InvokableWithExceptionInInvoke.class);
task.registerExecutionListener(listenerActor);
task.run();
// this should not overwrite the failure state
task.cancelExecution();
assertEquals(ExecutionState.FAILED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertTrue(task.getFailureCause().getMessage().contains("test"));
validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
validateUnregisterTask(task.getExecutionId());
// make sure the final state is correct and the task manager knows the changes
assertEquals(ExecutionState.CANCELED, task.getExecutionState());
verify(task).unregisterTask();
validateListenerMessage(ExecutionState.RUNNING, task, false);
validateListenerMessage(ExecutionState.FAILED, task, true);
}
catch (Exception e) {
e.printStackTrace();
......@@ -254,79 +525,205 @@ public class TaskTest {
}
@Test
public void testTaskWithEnvironment() {
public void testExecutionFailesAfterCanceling() {
try {
final JobID jid = new JobID();
final JobVertexID vid = new JobVertexID();
final ExecutionAttemptID eid = new ExecutionAttemptID();
TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(jid, vid, eid, "TestTask", 2, 7,
new Configuration(), new Configuration(), TestInvokableCorrect.class.getName(),
Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
Collections.<InputGateDeploymentDescriptor>emptyList(),
new ArrayList<BlobKey>(), 0);
Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender()));
doNothing().when(task).unregisterTask();
Task task = createTask(InvokableWithExceptionOnTrigger.class);
task.registerExecutionListener(listenerActor);
// run the task asynchronous
task.startTaskThread();
// wait till the task is in invoke
awaitLatch.await();
task.cancelExecution();
assertEquals(ExecutionState.CANCELING, task.getExecutionState());
RuntimeEnvironment env = new RuntimeEnvironment(mock(ActorRef.class), task, tdd, getClass().getClassLoader(),
mock(MemoryManager.class), mock(IOManager.class), mock(InputSplitProvider.class),
new BroadcastVariableManager(), MockNetworkEnvironment.getMock());
// this causes an exception
triggerLatch.trigger();
task.getExecutingThread().join();
task.setEnvironment(env);
// we should still be in state canceled
assertEquals(ExecutionState.CANCELED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertNull(task.getFailureCause());
assertEquals(ExecutionState.DEPLOYING, task.getExecutionState());
validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
validateUnregisterTask(task.getExecutionId());
validateListenerMessage(ExecutionState.RUNNING, task, false);
validateListenerMessage(ExecutionState.CANCELING, task, false);
validateListenerMessage(ExecutionState.CANCELED, task, false);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testExecutionFailesAfterTaskMarkedFailed() {
try {
Task task = createTask(InvokableWithExceptionOnTrigger.class);
task.registerExecutionListener(listenerActor);
// run the task asynchronous
task.startTaskThread();
// wait till the task is in invoke
awaitLatch.await();
task.failExternally(new Exception("external"));
assertEquals(ExecutionState.FAILED, task.getExecutionState());
// this causes an exception
triggerLatch.trigger();
task.getExecutingThread().join();
task.startExecution();
task.getEnvironment().getExecutingThread().join();
assertEquals(ExecutionState.FAILED, task.getExecutionState());
assertTrue(task.isCanceledOrFailed());
assertTrue(task.getFailureCause().getMessage().contains("external"));
assertEquals(ExecutionState.FINISHED, task.getExecutionState());
validateTaskManagerStateChange(ExecutionState.RUNNING, task, false);
validateUnregisterTask(task.getExecutionId());
verify(task).unregisterTask();
validateListenerMessage(ExecutionState.RUNNING, task, false);
validateListenerMessage(ExecutionState.FAILED, task, true);
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
private Task createTask(Class<? extends AbstractInvokable> invokable) {
LibraryCacheManager libCache = mock(LibraryCacheManager.class);
when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());
return createTask(invokable, libCache);
}
private Task createTask(Class<? extends AbstractInvokable> invokable,
LibraryCacheManager libCache) {
ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class);
NetworkEnvironment network = mock(NetworkEnvironment.class);
when(network.getPartitionManager()).thenReturn(partitionManager);
when(network.getPartitionConsumableNotifier()).thenReturn(consumableNotifier);
when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC);
return createTask(invokable, libCache, network);
}
@Test
public void testTaskWithEnvironmentAndException() {
private Task createTask(Class<? extends AbstractInvokable> invokable,
LibraryCacheManager libCache,
NetworkEnvironment networkEnvironment) {
TaskDeploymentDescriptor tdd = createTaskDeploymentDescriptor(invokable);
return new Task(tdd,
mock(MemoryManager.class),
mock(IOManager.class),
networkEnvironment,
mock(BroadcastVariableManager.class),
taskManagerMock, jobManagerMock,
new FiniteDuration(60, TimeUnit.SECONDS),
libCache,
mock(FileCache.class));
}
private TaskDeploymentDescriptor createTaskDeploymentDescriptor(Class<? extends AbstractInvokable> invokable) {
return new TaskDeploymentDescriptor(
new JobID(), new JobVertexID(), new ExecutionAttemptID(),
"Test Task", 0, 1,
new Configuration(), new Configuration(),
invokable.getName(),
Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
Collections.<InputGateDeploymentDescriptor>emptyList(),
Collections.<BlobKey>emptyList(),
0);
}
// ------------------------------------------------------------------------
// Validation Methods
// ------------------------------------------------------------------------
private void validateUnregisterTask(ExecutionAttemptID id) {
try {
final JobID jid = new JobID();
final JobVertexID vid = new JobVertexID();
final ExecutionAttemptID eid = new ExecutionAttemptID();
// we may have to wait for a bit to give the actors time to receive the message
// and put it into the queue
Object rawMessage = taskManagerMessages.poll(10, TimeUnit.SECONDS);
TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(jid, vid, eid, "TestTask", 2, 7,
new Configuration(), new Configuration(), TestInvokableWithException.class.getName(),
Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
Collections.<InputGateDeploymentDescriptor>emptyList(),
new ArrayList<BlobKey>(), 0);
assertNotNull("There is no additional TaskManager message", rawMessage);
assertTrue("TaskManager message is not 'UnregisterTask'", rawMessage instanceof TaskMessages.TaskInFinalState);
Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender()));
doNothing().when(task).unregisterTask();
RuntimeEnvironment env = new RuntimeEnvironment(mock(ActorRef.class), task, tdd, getClass().getClassLoader(),
mock(MemoryManager.class), mock(IOManager.class), mock(InputSplitProvider.class),
new BroadcastVariableManager(), MockNetworkEnvironment.getMock());
TaskMessages.TaskInFinalState message = (TaskMessages.TaskInFinalState) rawMessage;
assertEquals(id, message.executionID());
}
catch (InterruptedException e) {
fail("interrupted");
}
}
task.setEnvironment(env);
private void validateTaskManagerStateChange(ExecutionState state, Task task, boolean hasError) {
try {
// we may have to wait for a bit to give the actors time to receive the message
// and put it into the queue
Object rawMessage = taskManagerMessages.poll(10, TimeUnit.SECONDS);
assertNotNull("There is no additional TaskManager message", rawMessage);
assertTrue("TaskManager message is not 'UpdateTaskExecutionState'",
rawMessage instanceof TaskMessages.UpdateTaskExecutionState);
assertEquals(ExecutionState.DEPLOYING, task.getExecutionState());
TaskMessages.UpdateTaskExecutionState message =
(TaskMessages.UpdateTaskExecutionState) rawMessage;
task.startExecution();
task.getEnvironment().getExecutingThread().join();
TaskExecutionState taskState = message.taskExecutionState();
assertEquals(task.getJobID(), taskState.getJobID());
assertEquals(task.getExecutionId(), taskState.getID());
assertEquals(state, taskState.getExecutionState());
if (hasError) {
assertNotNull(taskState.getError(getClass().getClassLoader()));
} else {
assertNull(taskState.getError(getClass().getClassLoader()));
}
}
catch (InterruptedException e) {
fail("interrupted");
}
}
private void validateListenerMessage(ExecutionState state, Task task, boolean hasError) {
try {
// we may have to wait for a bit to give the actors time to receive the message
// and put it into the queue
TaskMessages.UpdateTaskExecutionState message =
(TaskMessages.UpdateTaskExecutionState) listenerMessages.poll(10, TimeUnit.SECONDS);
assertNotNull("There is no additional listener message", message);
assertEquals(ExecutionState.FAILED, task.getExecutionState());
TaskExecutionState taskState = message.taskExecutionState();
verify(task).unregisterTask();
assertEquals(task.getJobID(), taskState.getJobID());
assertEquals(task.getExecutionId(), taskState.getID());
assertEquals(state, taskState.getExecutionState());
if (hasError) {
assertNotNull(taskState.getError(getClass().getClassLoader()));
} else {
assertNull(taskState.getError(getClass().getClassLoader()));
}
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
catch (InterruptedException e) {
fail("interrupted");
}
}
// --------------------------------------------------------------------------------------------
// Mock invokable code
// --------------------------------------------------------------------------------------------
public static final class TestInvokableCorrect extends AbstractInvokable {
......@@ -336,16 +733,93 @@ public class TaskTest {
@Override
public void invoke() {}
@Override
public void cancel() throws Exception {
fail("This should not be called");
}
}
public static final class InvokableWithExceptionInRegisterInOut extends AbstractInvokable {
@Override
public void registerInputOutput() {
throw new RuntimeException("test");
}
@Override
public void invoke() {}
}
public static final class TestInvokableWithException extends AbstractInvokable {
public static final class InvokableWithExceptionInInvoke extends AbstractInvokable {
@Override
public void registerInputOutput() {}
@Override
public void invoke() throws Exception {
throw new Exception("test");
}
}
public static final class InvokableWithExceptionOnTrigger extends AbstractInvokable {
@Override
public void registerInputOutput() {}
@Override
public void invoke() {
awaitLatch.trigger();
// make sure that the interrupt call does not
// grab us out of the lock early
while (true) {
try {
triggerLatch.await();
break;
}
catch (InterruptedException e) {
// fall through the loop
}
}
throw new RuntimeException("test");
}
}
public static abstract class InvokableNonInstantiable extends AbstractInvokable {}
public static final class InvokableBlockingInRegisterInOut extends AbstractInvokable {
@Override
public void registerInputOutput() {
awaitLatch.trigger();
try {
triggerLatch.await();
}
catch (InterruptedException e) {
throw new RuntimeException();
}
}
@Override
public void invoke() {}
}
public static final class InvokableBlockingInInvoke extends AbstractInvokable {
@Override
public void registerInputOutput() {}
@Override
public void invoke() throws Exception {
throw new Exception("test exception");
awaitLatch.trigger();
// block forever
synchronized (this) {
wait();
}
}
}
}
......@@ -27,7 +27,7 @@ import org.apache.flink.runtime.io.disk.iomanager.IOManager
import org.apache.flink.runtime.io.network.NetworkEnvironment
import org.apache.flink.runtime.memorymanager.DefaultMemoryManager
import org.apache.flink.runtime.messages.Messages.Disconnect
import org.apache.flink.runtime.messages.TaskMessages.{UpdateTaskExecutionState, UnregisterTask}
import org.apache.flink.runtime.messages.TaskMessages.{TaskInFinalState, UpdateTaskExecutionState}
import org.apache.flink.runtime.taskmanager.{TaskManagerConfiguration, TaskManager}
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.NotifyWhenJobRemoved
import org.apache.flink.runtime.testingUtils.TestingMessages.DisableDisconnect
......@@ -95,8 +95,8 @@ class TestingTaskManager(config: TaskManagerConfiguration,
}
}
case UnregisterTask(executionID) =>
super.receiveWithLogMessages(UnregisterTask(executionID))
case TaskInFinalState(executionID) =>
super.receiveWithLogMessages(TaskInFinalState(executionID))
waitForRemoval.remove(executionID) match {
case Some(actors) => for(actor <- actors) actor ! true
case None =>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册