diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java index 7ab3bc90e67f29ef97d2000bb2addae837e70e63..081e3cae362e9722140bbb99f1f88df3cca15904 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.execution; +import akka.actor.ActorRef; import org.apache.flink.api.common.accumulators.Accumulator; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.Path; @@ -159,4 +160,6 @@ public interface Environment { InputGate[] getAllInputGates(); + // this should go away + ActorRef getJobManager(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/ExecutionState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/ExecutionState.java index 2fcaea16cb32d48e4f905060156dc6264146acab..9f4a5a71307b43ac0361c88a5f2f2624e370d82b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/ExecutionState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/ExecutionState.java @@ -35,10 +35,10 @@ package org.apache.flink.runtime.execution; * ... -> FAILED * * - * It is possible to enter the {@code FAILED} state from any other state. + *

<p>It is possible to enter the {@code FAILED} state from any other state. * - * The states {@code FINISHED}, {@code CANCELED}, and {@code FAILED} are - * considered terminal states. + * <p>The states {@code FINISHED}, {@code CANCELED}, and {@code FAILED} are + * considered terminal states.
*/ public enum ExecutionState { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/RuntimeEnvironment.java deleted file mode 100644 index 081d4bc8d4acee21cfa290682a17426fa641b7f2..0000000000000000000000000000000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/RuntimeEnvironment.java +++ /dev/null @@ -1,458 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.execution; - -import akka.actor.ActorRef; -import org.apache.flink.api.common.accumulators.Accumulator; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.core.fs.Path; -import org.apache.flink.runtime.accumulators.AccumulatorEvent; -import org.apache.flink.runtime.broadcast.BroadcastVariableManager; -import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; -import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; -import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; -import org.apache.flink.runtime.io.disk.iomanager.IOManager; -import org.apache.flink.runtime.io.network.NetworkEnvironment; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; -import org.apache.flink.runtime.io.network.partition.ResultPartition; -import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.io.network.partition.consumer.InputGate; -import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; -import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; -import org.apache.flink.api.common.JobID; -import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; -import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; -import org.apache.flink.runtime.memorymanager.MemoryManager; -import org.apache.flink.runtime.messages.accumulators.ReportAccumulatorResult; -import org.apache.flink.runtime.taskmanager.Task; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.FutureTask; -import java.util.concurrent.atomic.AtomicBoolean; - -import static com.google.common.base.Preconditions.checkElementIndex; -import static com.google.common.base.Preconditions.checkNotNull; - -public class RuntimeEnvironment implements Environment, Runnable { - - private static final Logger LOG = LoggerFactory.getLogger(RuntimeEnvironment.class); - - private static final ThreadGroup TASK_THREADS = new ThreadGroup("Task Threads"); - - /** The ActorRef to the job manager */ - private 
final ActorRef jobManager; - - /** The task that owns this environment */ - private final Task owner; - - /** The job configuration encapsulated in the environment object. */ - private final Configuration jobConfiguration; - - /** The task configuration encapsulated in the environment object. */ - private final Configuration taskConfiguration; - - /** ClassLoader for all user code classes */ - private final ClassLoader userCodeClassLoader; - - /** Instance of the class to be run in this environment. */ - private final AbstractInvokable invokable; - - /** The memory manager of the current environment (currently the one associated with the executing TaskManager). */ - private final MemoryManager memoryManager; - - /** The I/O manager of the current environment (currently the one associated with the executing TaskManager). */ - private final IOManager ioManager; - - /** The input split provider that can be queried for new input splits. */ - private final InputSplitProvider inputSplitProvider; - - /** The thread executing the task in the environment. */ - private Thread executingThread; - - private final BroadcastVariableManager broadcastVariableManager; - - private final Map> cacheCopyTasks = new HashMap>(); - - private final AtomicBoolean canceled = new AtomicBoolean(); - - private final ResultPartition[] producedPartitions; - private final ResultPartitionWriter[] writers; - - private final SingleInputGate[] inputGates; - - private final Map inputGatesById = new HashMap(); - - public RuntimeEnvironment( - ActorRef jobManager, Task owner, TaskDeploymentDescriptor tdd, ClassLoader userCodeClassLoader, - MemoryManager memoryManager, IOManager ioManager, InputSplitProvider inputSplitProvider, - BroadcastVariableManager broadcastVariableManager, NetworkEnvironment networkEnvironment) throws Exception { - - this.owner = checkNotNull(owner); - - this.memoryManager = checkNotNull(memoryManager); - this.ioManager = checkNotNull(ioManager); - this.inputSplitProvider = checkNotNull(inputSplitProvider); - this.jobManager = checkNotNull(jobManager); - - this.broadcastVariableManager = checkNotNull(broadcastVariableManager); - - try { - // Produced intermediate result partitions - final List partitions = tdd.getProducedPartitions(); - - this.producedPartitions = new ResultPartition[partitions.size()]; - this.writers = new ResultPartitionWriter[partitions.size()]; - - for (int i = 0; i < this.producedPartitions.length; i++) { - ResultPartitionDeploymentDescriptor desc = partitions.get(i); - ResultPartitionID partitionId = new ResultPartitionID(desc.getPartitionId(), owner.getExecutionId()); - - this.producedPartitions[i] = new ResultPartition( - this, - owner.getJobID(), - partitionId, - desc.getPartitionType(), - desc.getNumberOfSubpartitions(), - networkEnvironment.getPartitionManager(), - networkEnvironment.getPartitionConsumableNotifier(), - ioManager, - networkEnvironment.getDefaultIOMode()); - - writers[i] = new ResultPartitionWriter(this.producedPartitions[i]); - } - - // Consumed intermediate result partitions - final List consumedPartitions = tdd.getInputGates(); - - this.inputGates = new SingleInputGate[consumedPartitions.size()]; - - for (int i = 0; i < inputGates.length; i++) { - inputGates[i] = SingleInputGate.create( - this, consumedPartitions.get(i), networkEnvironment); - - // The input gates are organized by key for task updates/channel updates at runtime - inputGatesById.put(inputGates[i].getConsumedResultId(), inputGates[i]); - } - - this.jobConfiguration = tdd.getJobConfiguration(); - 
this.taskConfiguration = tdd.getTaskConfiguration(); - - // ---------------------------------------------------------------- - // Invokable setup - // ---------------------------------------------------------------- - // Note: This has to be done *after* the readers and writers have - // been setup, because the invokable relies on them for I/O. - // ---------------------------------------------------------------- - - // Load and instantiate the invokable class - this.userCodeClassLoader = checkNotNull(userCodeClassLoader); - // Class of the task to run in this environment - Class invokableClass; - try { - final String className = tdd.getInvokableClassName(); - invokableClass = Class.forName(className, true, userCodeClassLoader).asSubclass(AbstractInvokable.class); - } - catch (Throwable t) { - throw new Exception("Could not load invokable class.", t); - } - - try { - this.invokable = invokableClass.newInstance(); - } - catch (Throwable t) { - throw new Exception("Could not instantiate the invokable class.", t); - } - - this.invokable.setEnvironment(this); - this.invokable.registerInputOutput(); - } - catch (Throwable t) { - throw new Exception("Error setting up runtime environment: " + t.getMessage(), t); - } - } - - /** - * Returns the task invokable instance. - */ - public AbstractInvokable getInvokable() { - return this.invokable; - } - - @Override - public JobID getJobID() { - return this.owner.getJobID(); - } - - @Override - public JobVertexID getJobVertexId() { - return this.owner.getVertexID(); - } - - @Override - public void run() { - // quick fail in case the task was cancelled while the thread was started - if (owner.isCanceledOrFailed()) { - owner.cancelingDone(); - return; - } - - try { - Thread.currentThread().setContextClassLoader(userCodeClassLoader); - invokable.invoke(); - - // Make sure, we enter the catch block when the task has been canceled - if (owner.isCanceledOrFailed()) { - throw new CancelTaskException("Task has been canceled or failed"); - } - - // Finish the produced partitions - if (producedPartitions != null) { - for (ResultPartition partition : producedPartitions) { - if (partition != null) { - partition.finish(); - } - } - } - - if (owner.isCanceledOrFailed()) { - throw new CancelTaskException(); - } - - // Finally, switch execution state to FINISHED and report to job manager - if (!owner.markAsFinished()) { - throw new Exception("Could *not* notify job manager that the task is finished."); - } - } - catch (Throwable t) { - if (!owner.isCanceledOrFailed()) { - // Perform clean up when the task failed and has been not canceled by the user - try { - invokable.cancel(); - } - catch (Throwable t2) { - LOG.error("Error while canceling the task", t2); - } - } - - // if we are already set as cancelled or failed (when failure is triggered externally), - // mark that the thread is done. - if (owner.isCanceledOrFailed() || t instanceof CancelTaskException) { - owner.cancelingDone(); - } - else { - // failure from inside the task thread. notify the task of the failure - owner.markFailed(t); - } - } - } - - /** - * Returns the thread, which is assigned to execute the user code. 
- */ - public Thread getExecutingThread() { - synchronized (this) { - if (executingThread == null) { - String name = owner.getTaskNameWithSubtasks(); - - if (LOG.isDebugEnabled()) { - name = name + " (" + owner.getExecutionId() + ")"; - } - - executingThread = new Thread(TASK_THREADS, this, name); - } - - return executingThread; - } - } - - public void cancelExecution() { - if (!canceled.compareAndSet(false, true)) { - return; - } - - LOG.info("Canceling {} ({}).", owner.getTaskNameWithSubtasks(), owner.getExecutionId()); - - // Request user code to shut down - if (invokable != null) { - try { - invokable.cancel(); - } - catch (Throwable e) { - LOG.error("Error while canceling the task.", e); - } - } - - final Thread executingThread = this.executingThread; - if (executingThread != null) { - // interrupt the running thread and wait for it to die - executingThread.interrupt(); - try { - executingThread.join(5000); - } - catch (InterruptedException e) { - } - if (!executingThread.isAlive()) { - return; - } - // Continuously interrupt the user thread until it changed to state CANCELED - while (executingThread != null && executingThread.isAlive()) { - LOG.warn("Task " + owner.getTaskNameWithSubtasks() + " did not react to cancelling signal. Sending repeated interrupt."); - if (LOG.isDebugEnabled()) { - StringBuilder bld = new StringBuilder("Task ").append(owner.getTaskNameWithSubtasks()).append(" is stuck in method:\n"); - StackTraceElement[] stack = executingThread.getStackTrace(); - for (StackTraceElement e : stack) { - bld.append(e).append('\n'); - } - LOG.debug(bld.toString()); - } - executingThread.interrupt(); - try { - executingThread.join(1000); - } - catch (InterruptedException e) { - } - } - } - } - - @Override - public ActorRef getJobManager() { - return jobManager; - } - - @Override - public IOManager getIOManager() { - return ioManager; - } - - @Override - public MemoryManager getMemoryManager() { - return memoryManager; - } - - @Override - public BroadcastVariableManager getBroadcastVariableManager() { - return broadcastVariableManager; - } - - @Override - public void reportAccumulators(Map> accumulators) { - AccumulatorEvent evt; - try { - evt = new AccumulatorEvent(getJobID(), accumulators); - } - catch (IOException e) { - throw new RuntimeException("Cannot serialize accumulators to send them to JobManager", e); - } - - ReportAccumulatorResult accResult = new ReportAccumulatorResult(getJobID(), owner.getExecutionId(), evt); - jobManager.tell(accResult, ActorRef.noSender()); - } - - @Override - public ResultPartitionWriter getWriter(int index) { - checkElementIndex(index, writers.length, "Illegal environment writer request."); - - return writers[checkElementIndex(index, writers.length)]; - } - - @Override - public ResultPartitionWriter[] getAllWriters() { - return writers; - } - - @Override - public InputGate getInputGate(int index) { - checkElementIndex(index, inputGates.length); - - return inputGates[index]; - } - - @Override - public SingleInputGate[] getAllInputGates() { - return inputGates; - } - - public ResultPartition[] getProducedPartitions() { - return producedPartitions; - } - - public SingleInputGate getInputGateById(IntermediateDataSetID id) { - return inputGatesById.get(id); - } - - @Override - public Configuration getTaskConfiguration() { - return taskConfiguration; - } - - @Override - public Configuration getJobConfiguration() { - return jobConfiguration; - } - - @Override - public int getNumberOfSubtasks() { - return owner.getNumberOfSubtasks(); - } - - @Override 
- public int getIndexInSubtaskGroup() { - return owner.getSubtaskIndex(); - } - - @Override - public String getTaskName() { - return owner.getTaskName(); - } - - @Override - public InputSplitProvider getInputSplitProvider() { - return inputSplitProvider; - } - - @Override - public String getTaskNameWithSubtasks() { - return owner.getTaskNameWithSubtasks(); - } - - @Override - public ClassLoader getUserClassLoader() { - return userCodeClassLoader; - } - - public void addCopyTasksForCacheFile(Map> copyTasks) { - cacheCopyTasks.putAll(copyTasks); - } - - public void addCopyTaskForCacheFile(String name, FutureTask copyTask) { - cacheCopyTasks.put(name, copyTask); - } - - @Override - public Map> getCopyTask() { - return cacheCopyTasks; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index af55ebf4de228340622143f0b0d401b0b9724baa..259ea5556507ec1b9a53110553ca97230dcc4de8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -243,7 +243,7 @@ public class NetworkEnvironment { public void registerTask(Task task) throws IOException { final ResultPartition[] producedPartitions = task.getProducedPartitions(); - final ResultPartitionWriter[] writers = task.getWriters(); + final ResultPartitionWriter[] writers = task.getAllWriters(); if (writers.length != producedPartitions.length) { throw new IllegalStateException("Unequal number of writers and partitions."); @@ -288,7 +288,7 @@ public class NetworkEnvironment { } // Setup the buffer pool for each buffer reader - final SingleInputGate[] inputGates = task.getInputGates(); + final SingleInputGate[] inputGates = task.getAllInputGates(); for (SingleInputGate gate : inputGates) { BufferPool bufferPool = null; @@ -329,10 +329,9 @@ public class NetworkEnvironment { partitionManager.releasePartitionsProducedBy(executionId); } - ResultPartitionWriter[] writers = task.getWriters(); - + ResultPartitionWriter[] writers = task.getAllWriters(); if (writers != null) { - for (ResultPartitionWriter writer : task.getWriters()) { + for (ResultPartitionWriter writer : writers) { taskEventDispatcher.unregisterWriter(writer); } } @@ -344,7 +343,7 @@ public class NetworkEnvironment { } } - final SingleInputGate[] inputGates = task.getInputGates(); + final SingleInputGate[] inputGates = task.getAllInputGates(); if (inputGates != null) { for (SingleInputGate gate : inputGates) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java index f06c8fb077dc367a1715907a48a21de745c5d93c..df1f254c60292bd15b8733e7c3bfbb2201ff866c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.io.network.partition; -import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManager.IOMode; @@ -76,9 +75,8 @@ import static com.google.common.base.Preconditions.checkState; public 
class ResultPartition implements BufferPoolOwner { private static final Logger LOG = LoggerFactory.getLogger(ResultPartition.class); - - /** The owning environment. Mainly for debug purposes. */ - private final Environment owner; + + private final String owningTaskName; private final JobID jobId; @@ -120,7 +118,7 @@ public class ResultPartition implements BufferPoolOwner { private long totalNumberOfBytes; public ResultPartition( - Environment owner, + String owningTaskName, JobID jobId, ResultPartitionID partitionId, ResultPartitionType partitionType, @@ -130,7 +128,7 @@ public class ResultPartition implements BufferPoolOwner { IOManager ioManager, IOMode defaultIoMode) { - this.owner = checkNotNull(owner); + this.owningTaskName = checkNotNull(owningTaskName); this.jobId = checkNotNull(jobId); this.partitionId = checkNotNull(partitionId); this.partitionType = checkNotNull(partitionType); @@ -162,7 +160,7 @@ public class ResultPartition implements BufferPoolOwner { // Initially, partitions should be consumed once before release. pin(); - LOG.debug("{}: Initialized {}", owner.getTaskNameWithSubtasks(), this); + LOG.debug("{}: Initialized {}", owningTaskName, this); } /** @@ -281,7 +279,7 @@ public class ResultPartition implements BufferPoolOwner { */ public void release() { if (isReleased.compareAndSet(false, true)) { - LOG.debug("{}: Releasing {}.", owner.getTaskNameWithSubtasks(), this); + LOG.debug("{}: Releasing {}.", owningTaskName, this); // Release all subpartitions for (ResultSubpartition subpartition : subpartitions) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index b0d138a8a07ca2de1e3464fe1baa20ff4a208f91..acda1d82c9e1bbb58422f8a903a29c3d88b9dfd2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -24,7 +24,6 @@ import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionLocation; import org.apache.flink.runtime.event.task.AbstractEvent; import org.apache.flink.runtime.event.task.TaskEvent; -import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; @@ -101,8 +100,8 @@ public class SingleInputGate implements InputGate { /** Lock object to guard partition requests and runtime channel updates. */ private final Object requestLock = new Object(); - /** The owning environment. Mainly for debug purposes. */ - private final Environment owner; + /** The name of the owning task, for logging purposes. */ + private final String owningTaskName; /** * The ID of the consumed intermediate result. 
Each input gate consumes partitions of the @@ -153,12 +152,12 @@ public class SingleInputGate implements InputGate { private int numberOfUninitializedChannels; public SingleInputGate( - Environment owner, + String owningTaskName, IntermediateDataSetID consumedResultId, int consumedSubpartitionIndex, int numberOfInputChannels) { - this.owner = checkNotNull(owner); + this.owningTaskName = checkNotNull(owningTaskName); this.consumedResultId = checkNotNull(consumedResultId); checkArgument(consumedSubpartitionIndex >= 0); @@ -265,7 +264,7 @@ public class SingleInputGate implements InputGate { synchronized (requestLock) { if (!isReleased) { try { - LOG.debug("{}: Releasing {}.", owner.getTaskNameWithSubtasks(), this); + LOG.debug("{}: Releasing {}.", owningTaskName, this); for (InputChannel inputChannel : inputChannels.values()) { try { @@ -410,7 +409,7 @@ public class SingleInputGate implements InputGate { * Creates an input gate and all of its input channels. */ public static SingleInputGate create( - Environment owner, + String owningTaskName, InputGateDeploymentDescriptor igdd, NetworkEnvironment networkEnvironment) { @@ -422,7 +421,7 @@ public class SingleInputGate implements InputGate { final InputChannelDeploymentDescriptor[] icdd = checkNotNull(igdd.getInputChannelDeploymentDescriptors()); final SingleInputGate inputGate = new SingleInputGate( - owner, consumedResultId, consumedSubpartitionIndex, icdd.length); + owningTaskName, consumedResultId, consumedSubpartitionIndex, icdd.length); // Create the input channels. There is one input channel for each consumed partition. final InputChannel[] inputChannels = new InputChannel[icdd.length]; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java index b528f759cf02059714c933b0df6d7a06c3d3d172..2bee0940cfc41e40819a9390dc31422a8497a7e3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java @@ -1067,7 +1067,8 @@ public class RegularPactTask extends AbstractInvokable i public DistributedRuntimeUDFContext createRuntimeContext(String taskName) { Environment env = getEnvironment(); return new DistributedRuntimeUDFContext(taskName, env.getNumberOfSubtasks(), - env.getIndexInSubtaskGroup(), getUserCodeClassLoader(), getExecutionConfig(), env.getCopyTask()); + env.getIndexInSubtaskGroup(), getUserCodeClassLoader(), getExecutionConfig(), + env.getDistributedCacheEntries()); } // -------------------------------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java new file mode 100644 index 0000000000000000000000000000000000000000..1321336bff44f27bd85b1f4192687238e173028f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.taskmanager; + +import akka.actor.ActorRef; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.accumulators.Accumulator; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.accumulators.AccumulatorEvent; +import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; +import org.apache.flink.runtime.memorymanager.MemoryManager; +import org.apache.flink.runtime.messages.accumulators.ReportAccumulatorResult; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.Future; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * In implementation of the {@link Environment}. + */ +public class RuntimeEnvironment implements Environment { + + private final JobID jobId; + private final JobVertexID jobVertexId; + private final ExecutionAttemptID executionId; + + private final String taskName; + private final String taskNameWithSubtasks; + private final int subtaskIndex; + private final int parallelism; + + private final Configuration jobConfiguration; + private final Configuration taskConfiguration; + + private final ClassLoader userCodeClassLoader; + + private final MemoryManager memManager; + private final IOManager ioManager; + private final BroadcastVariableManager bcVarManager; + private final InputSplitProvider splitProvider; + + private final Map> distCacheEntries; + + private final ResultPartitionWriter[] writers; + private final InputGate[] inputGates; + + private final ActorRef jobManagerActor; + + // ------------------------------------------------------------------------ + + public RuntimeEnvironment(JobID jobId, JobVertexID jobVertexId, ExecutionAttemptID executionId, + String taskName, String taskNameWithSubtasks, + int subtaskIndex, int parallelism, + Configuration jobConfiguration, Configuration taskConfiguration, + ClassLoader userCodeClassLoader, + MemoryManager memManager, IOManager ioManager, + BroadcastVariableManager bcVarManager, + InputSplitProvider splitProvider, + Map> distCacheEntries, + ResultPartitionWriter[] writers, + InputGate[] inputGates, + ActorRef jobManagerActor) { + + checkArgument(parallelism > 0 && subtaskIndex >= 0 && subtaskIndex < parallelism); + + this.jobId = checkNotNull(jobId); + this.jobVertexId = checkNotNull(jobVertexId); + this.executionId = checkNotNull(executionId); + this.taskName = checkNotNull(taskName); + this.taskNameWithSubtasks = checkNotNull(taskNameWithSubtasks); + this.subtaskIndex = subtaskIndex; + 
this.parallelism = parallelism; + this.jobConfiguration = checkNotNull(jobConfiguration); + this.taskConfiguration = checkNotNull(taskConfiguration); + this.userCodeClassLoader = checkNotNull(userCodeClassLoader); + this.memManager = checkNotNull(memManager); + this.ioManager = checkNotNull(ioManager); + this.bcVarManager = checkNotNull(bcVarManager); + this.splitProvider = checkNotNull(splitProvider); + this.distCacheEntries = checkNotNull(distCacheEntries); + this.writers = checkNotNull(writers); + this.inputGates = checkNotNull(inputGates); + this.jobManagerActor = checkNotNull(jobManagerActor); + } + + + // ------------------------------------------------------------------------ + + @Override + public JobID getJobID() { + return jobId; + } + + @Override + public JobVertexID getJobVertexId() { + return jobVertexId; + } + + @Override + public ExecutionAttemptID getExecutionId() { + return executionId; + } + + @Override + public String getTaskName() { + return taskName; + } + + @Override + public String getTaskNameWithSubtasks() { + return taskNameWithSubtasks; + } + + @Override + public int getNumberOfSubtasks() { + return parallelism; + } + + @Override + public int getIndexInSubtaskGroup() { + return subtaskIndex; + } + + @Override + public Configuration getJobConfiguration() { + return jobConfiguration; + } + + @Override + public Configuration getTaskConfiguration() { + return taskConfiguration; + } + + @Override + public ClassLoader getUserClassLoader() { + return userCodeClassLoader; + } + + @Override + public MemoryManager getMemoryManager() { + return memManager; + } + + @Override + public IOManager getIOManager() { + return ioManager; + } + + @Override + public BroadcastVariableManager getBroadcastVariableManager() { + return bcVarManager; + } + + @Override + public InputSplitProvider getInputSplitProvider() { + return splitProvider; + } + + @Override + public Map> getDistributedCacheEntries() { + return distCacheEntries; + } + + @Override + public ResultPartitionWriter getWriter(int index) { + return writers[index]; + } + + @Override + public ResultPartitionWriter[] getAllWriters() { + return writers; + } + + @Override + public InputGate getInputGate(int index) { + return inputGates[index]; + } + + @Override + public InputGate[] getAllInputGates() { + return inputGates; + } + + @Override + public void reportAccumulators(Map> accumulators) { + AccumulatorEvent evt; + try { + evt = new AccumulatorEvent(getJobID(), accumulators); + } + catch (IOException e) { + throw new RuntimeException("Cannot serialize accumulators to send them to JobManager", e); + } + + ReportAccumulatorResult accResult = new ReportAccumulatorResult(jobId, executionId, evt); + jobManagerActor.tell(accResult, ActorRef.noSender()); + } + + @Override + public ActorRef getJobManager() { + return jobManagerActor; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index e6eee5be34dcf1f028eb280e6549c541517e72ac..f12344bccfc4b2fd6d25392c628bc6ec3dbc1f24 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -19,382 +19,755 @@ package org.apache.flink.runtime.taskmanager; import akka.actor.ActorRef; +import akka.util.Timeout; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.cache.DistributedCache; import org.apache.flink.configuration.Configuration; 
+import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.blob.BlobKey; +import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; +import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; +import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; +import org.apache.flink.runtime.execution.CancelTaskException; +import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.execution.ExecutionState; -import org.apache.flink.runtime.execution.RuntimeEnvironment; +import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.filecache.FileCache; +import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.partition.ResultPartition; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; -import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.jobgraph.tasks.BarrierTransceiver; +import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier; import org.apache.flink.runtime.memorymanager.MemoryManager; -import org.apache.flink.runtime.messages.ExecutionGraphMessages; -import org.apache.flink.runtime.messages.TaskMessages.UnregisterTask; -import org.apache.flink.runtime.profiling.TaskManagerProfiler; -import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.runtime.messages.TaskMessages; +import org.apache.flink.runtime.messages.TaskMessages.TaskInFinalState; +import org.apache.flink.runtime.messages.TaskManagerMessages.FatalError; +import org.apache.flink.runtime.state.StateHandle; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import scala.concurrent.duration.FiniteDuration; + +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -public class Task { +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +/** + * The Task represents one execution of a parallel subtask on a TaskManager. + * A Task wraps a Flink operator (which may be a user function) and + * runs it, providing all service necessary for example to consume input data, + * produce its results (intermediate result partitions) and communicate + * with the JobManager. + * + *

<p>The Flink operators (implemented as subclasses of + * {@link org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable}) have only data + * readers, writers, and certain event callbacks. The task connects those to the + * network stack and actor messages, and tracks the state of the execution and + * handles exceptions. + * + * <p>Tasks have no knowledge about how they relate to other tasks, or whether they + * are the first attempt to execute the task, or a repeated attempt. All of that + * is only known to the JobManager. All the task knows are its own runnable code, + * the task's configuration, and the IDs of the intermediate results to consume and + * produce (if any). + * + * <p>Each Task is run by one dedicated thread.
+ */ +public class Task implements Runnable { + /** The class logger. */ + private static final Logger LOG = LoggerFactory.getLogger(Task.class); + + /** The tread group that contains all task threads */ + private static final ThreadGroup TASK_THREADS_GROUP = new ThreadGroup("Flink Task Threads"); + /** For atomic state updates */ private static final AtomicReferenceFieldUpdater STATE_UPDATER = AtomicReferenceFieldUpdater.newUpdater(Task.class, ExecutionState.class, "executionState"); - /** The log object used for debugging. */ - private static final Logger LOG = LoggerFactory.getLogger(Task.class); - - // -------------------------------------------------------------------------------------------- + // ------------------------------------------------------------------------ + // Constant fields that are part of the initial Task construction + // ------------------------------------------------------------------------ + /** The job that the task belongs to */ private final JobID jobId; + /** The vertex in the JobGraph whose code the task executes */ private final JobVertexID vertexId; + /** The execution attempt of the parallel subtask */ + private final ExecutionAttemptID executionId; + + /** The index of the parallel subtask, in [0, numberOfSubtasks) */ private final int subtaskIndex; - private final int numberOfSubtasks; - - private final ExecutionAttemptID executionId; + /** The number of parallel subtasks for the JobVertex/ExecutionJobVertex that this task belongs to */ + private final int parallelism; + /** The name of the task */ private final String taskName; + /** The name of the task, including the subtask index and the parallelism */ + private final String taskNameWithSubtask; + + /** The job-wide configuration object */ + private final Configuration jobConfiguration; + + /** The task-specific configuration */ + private final Configuration taskConfiguration; + + /** The jar files used by this task */ + private final List requiredJarFiles; + + /** The name of the class that holds the invokable code */ + private final String nameOfInvokableClass; + + /** The handle to the state that the operator was initialized with */ + private final StateHandle operatorState; + + /** The memory manager to be used by this task */ + private final MemoryManager memoryManager; + + /** The I/O manager to be used by this task */ + private final IOManager ioManager; + + /** The BroadcastVariableManager to be used by this task */ + private final BroadcastVariableManager broadcastVariableManager; + + private final ResultPartition[] producedPartitions; + + private final ResultPartitionWriter[] writers; + + private final SingleInputGate[] inputGates; + + private final Map inputGatesById; + + /** The TaskManager actor that spawned this task */ private final ActorRef taskManager; - private final List executionListenerActors = new CopyOnWriteArrayList(); + /** The JobManager actor */ + private final ActorRef jobManager; + + /** All actors that want to be notified about changes in the task's execution state */ + private final List executionListenerActors; + + /** The timeout for all ask operations on actors */ + private final Timeout actorAskTimeout; + + private final LibraryCacheManager libraryCache; + + private final FileCache fileCache; + + private final NetworkEnvironment network; - /** The environment (with the invokable) executed by this task */ - private volatile RuntimeEnvironment environment; + /** The thread that executes the task */ + private final Thread executingThread; + + // 
------------------------------------------------------------------------ + // Fields that control the task execution + // ------------------------------------------------------------------------ + + private final AtomicBoolean invokableHasBeenCanceled = new AtomicBoolean(false); + + /** The invokable of this task, if initialized */ + private volatile AbstractInvokable invokable; + /** The current execution state of the task */ - private volatile ExecutionState executionState = ExecutionState.DEPLOYING; + private volatile ExecutionState executionState = ExecutionState.CREATED; + /** The observed exception, in case the task execution failed */ private volatile Throwable failureCause; - // -------------------------------------------------------------------------------------------- + + /** + *

<p>IMPORTANT: This constructor may not start any work that would need to + * be undone in the case of a failing task deployment.
+ */ + public Task(TaskDeploymentDescriptor tdd, + MemoryManager memManager, + IOManager ioManager, + NetworkEnvironment networkEnvironment, + BroadcastVariableManager bcVarManager, + ActorRef taskManagerActor, + ActorRef jobManagerActor, + FiniteDuration actorAskTimeout, + LibraryCacheManager libraryCache, + FileCache fileCache) + { + checkArgument(tdd.getNumberOfSubtasks() > 0); + checkArgument(tdd.getIndexInSubtaskGroup() >= 0); + checkArgument(tdd.getIndexInSubtaskGroup() < tdd.getNumberOfSubtasks()); + + this.jobId = checkNotNull(tdd.getJobID()); + this.vertexId = checkNotNull(tdd.getVertexID()); + this.executionId = checkNotNull(tdd.getExecutionId()); + this.subtaskIndex = tdd.getIndexInSubtaskGroup(); + this.parallelism = tdd.getNumberOfSubtasks(); + this.taskName = checkNotNull(tdd.getTaskName()); + this.taskNameWithSubtask = getTaskNameWithSubtask(taskName, subtaskIndex, parallelism); + this.jobConfiguration = checkNotNull(tdd.getJobConfiguration()); + this.taskConfiguration = checkNotNull(tdd.getTaskConfiguration()); + this.requiredJarFiles = checkNotNull(tdd.getRequiredJarFiles()); + this.nameOfInvokableClass = checkNotNull(tdd.getInvokableClassName()); + this.operatorState = tdd.getOperatorStates(); + + this.memoryManager = checkNotNull(memManager); + this.ioManager = checkNotNull(ioManager); + this.broadcastVariableManager =checkNotNull(bcVarManager); + + this.jobManager = checkNotNull(jobManagerActor); + this.taskManager = checkNotNull(taskManagerActor); + this.actorAskTimeout = new Timeout(checkNotNull(actorAskTimeout)); + + this.libraryCache = checkNotNull(libraryCache); + this.fileCache = checkNotNull(fileCache); + this.network = checkNotNull(networkEnvironment); + + this.executionListenerActors = new CopyOnWriteArrayList(); + + // create the reader and writer structures + + final String taskNameWithSubtasksAndId = + Task.getTaskNameWithSubtaskAndID(taskName, subtaskIndex, parallelism, executionId); + + List partitions = tdd.getProducedPartitions(); + List consumedPartitions = tdd.getInputGates(); + + // Produced intermediate result partitions + this.producedPartitions = new ResultPartition[partitions.size()]; + this.writers = new ResultPartitionWriter[partitions.size()]; + + for (int i = 0; i < this.producedPartitions.length; i++) { + ResultPartitionDeploymentDescriptor desc = partitions.get(i); + ResultPartitionID partitionId = new ResultPartitionID(desc.getPartitionId(), executionId); + + this.producedPartitions[i] = new ResultPartition( + taskNameWithSubtasksAndId, + jobId, + partitionId, + desc.getPartitionType(), + desc.getNumberOfSubpartitions(), + networkEnvironment.getPartitionManager(), + networkEnvironment.getPartitionConsumableNotifier(), + ioManager, + networkEnvironment.getDefaultIOMode()); + + this.writers[i] = new ResultPartitionWriter(this.producedPartitions[i]); + } + + // Consumed intermediate result partitions + this.inputGates = new SingleInputGate[consumedPartitions.size()]; + this.inputGatesById = new HashMap(); - public Task(JobID jobId, JobVertexID vertexId, int taskIndex, int parallelism, - ExecutionAttemptID executionId, String taskName, ActorRef taskManager) { + for (int i = 0; i < this.inputGates.length; i++) { + SingleInputGate gate = SingleInputGate.create( + taskNameWithSubtasksAndId, consumedPartitions.get(i), networkEnvironment); - this.jobId = jobId; - this.vertexId = vertexId; - this.subtaskIndex = taskIndex; - this.numberOfSubtasks = parallelism; - this.executionId = executionId; - this.taskName = taskName; - this.taskManager = 
taskManager; + this.inputGates[i] = gate; + inputGatesById.put(gate.getConsumedResultId(), gate); + } + + // finally, create the executing thread, but do not start it + executingThread = new Thread(TASK_THREADS_GROUP, this, taskNameWithSubtask); } - /** - * Returns the ID of the job this task belongs to. - */ + // ------------------------------------------------------------------------ + // Accessors + // ------------------------------------------------------------------------ + public JobID getJobID() { - return this.jobId; + return jobId; } - /** - * Returns the ID of this task vertex. - */ - public JobVertexID getVertexID() { - return this.vertexId; + public JobVertexID getJobVertexId() { + return vertexId; } - /** - * Gets the index of the parallel subtask [0, parallelism). - */ - public int getSubtaskIndex() { + public ExecutionAttemptID getExecutionId() { + return executionId; + } + + public int getIndexInSubtaskGroup() { return subtaskIndex; } - /** - * Gets the total number of subtasks of the task that this subtask belongs to. - */ public int getNumberOfSubtasks() { - return numberOfSubtasks; + return parallelism; } - /** - * Gets the ID of the execution attempt. - */ - public ExecutionAttemptID getExecutionId() { - return executionId; + public String getTaskName() { + return taskName; + } + + public String getTaskNameWithSubtasks() { + return taskNameWithSubtask; + } + + public Configuration getJobConfiguration() { + return jobConfiguration; + } + + public Configuration getTaskConfiguration() { + return this.taskConfiguration; + } + + public ResultPartitionWriter[] getAllWriters() { + return writers; + } + + public SingleInputGate[] getAllInputGates() { + return inputGates; + } + + public ResultPartition[] getProducedPartitions() { + return producedPartitions; + } + + public SingleInputGate getInputGateById(IntermediateDataSetID id) { + return inputGatesById.get(id); + } + + public Thread getExecutingThread() { + return executingThread; } + // ------------------------------------------------------------------------ + // Task Execution + // ------------------------------------------------------------------------ + /** * Returns the current execution state of the task. + * @return The current execution state of the task. */ public ExecutionState getExecutionState() { return this.executionState; } - public void setEnvironment(RuntimeEnvironment environment) { - this.environment = environment; - } - - public RuntimeEnvironment getEnvironment() { - return environment; - } - + /** + * Checks whether the task has failed, is canceled, or is being canceled at the moment. + * @return True is the task in state FAILED, CANCELING, or CANCELED, false otherwise. + */ public boolean isCanceledOrFailed() { return executionState == ExecutionState.CANCELING || executionState == ExecutionState.CANCELED || executionState == ExecutionState.FAILED; } - public String getTaskName() { - if (LOG.isDebugEnabled()) { - return taskName + " (" + executionId + ")"; - } else { - return taskName; - } - } - - public String getTaskNameWithSubtasks() { - if (LOG.isDebugEnabled()) { - return this.taskName + " (" + (this.subtaskIndex + 1) + "/" + this.numberOfSubtasks + - ") (" + executionId + ")"; - } else { - return this.taskName + " (" + (this.subtaskIndex + 1) + "/" + this.numberOfSubtasks + ")"; - } - } - + /** + * If the task has failed, this method gets the exception that caused this task to fail. + * Otherwise this method returns null. 
+ * + * @return The exception that caused the task to fail, or null, if the task has not failed. + */ public Throwable getFailureCause() { return failureCause; } - // ---------------------------------------------------------------------------------------------------------------- - // States and Transitions - // ---------------------------------------------------------------------------------------------------------------- - /** - * Marks the task as finished. This succeeds, if the task was previously in the state - * "RUNNING", otherwise it fails. Failure indicates that the task was either - * canceled, or set to failed. - * - * @return True, if the task correctly enters the state FINISHED. + * Starts the task's thread. */ - public boolean markAsFinished() { - if (STATE_UPDATER.compareAndSet(this, ExecutionState.RUNNING, ExecutionState.FINISHED)) { - notifyObservers(ExecutionState.FINISHED, null); - unregisterTask(); - return true; - } - else { - return false; - } + public void startTaskThread() { + executingThread.start(); } - public void markFailed(Throwable error) { + /** + * The core work method that bootstraps the task and executes it code + */ + public void run() { + + // ---------------------------- + // Initial State transition + // ---------------------------- while (true) { ExecutionState current = this.executionState; - - // if canceled, fine. we are done, and the jobmanager has been told - if (current == ExecutionState.CANCELED) { + if (current == ExecutionState.CREATED) { + if (STATE_UPDATER.compareAndSet(this, ExecutionState.CREATED, ExecutionState.DEPLOYING)) { + // success, we can start our work + break; + } + } + else if (current == ExecutionState.FAILED) { + // we were immediately failed. tell the TaskManager that we reached our final state + notifyFinalState(); return; } + else if (current == ExecutionState.CANCELING) { + if (STATE_UPDATER.compareAndSet(this, ExecutionState.CANCELING, ExecutionState.CANCELED)) { + // we were immediately canceled. tell the TaskManager that we reached our final state + notifyFinalState(); + return; + } + } + else { + throw new IllegalStateException("Invalid state for beginning of task operation"); + } + } - // if canceling, we are done, but we cannot be sure that the jobmanager has been told. 
- // after all, we may have recognized our failure state before the cancelling and never sent a canceled - // message back - else if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) { - this.failureCause = error; + // all resource acquisitions and registrations from here on + // need to be undone in the end - notifyObservers(ExecutionState.FAILED, ExceptionUtils.stringifyException(error)); - unregisterTask(); + Map> distributedCacheEntries = new HashMap>(); - return; - } - } - } + AbstractInvokable invokable = null; - public void cancelExecution() { - while (true) { - ExecutionState current = this.executionState; + try { + // ---------------------------- + // Task Bootstrap - We periodically + // check for canceling as a shortcut + // ---------------------------- - // if the task is already canceled (or canceling) or finished or failed, - // then we need not do anything - if (current == ExecutionState.FINISHED || current == ExecutionState.CANCELED || - current == ExecutionState.CANCELING || current == ExecutionState.FAILED) { - return; - } + // first of all, get a user-code classloader + // this may involve downloading the job's JAR files and/or classes + LOG.info("Loading JAR files for task " + taskNameWithSubtask); + final ClassLoader userCodeClassLoader = createUserCodeClassloader(libraryCache); - if (current == ExecutionState.DEPLOYING) { - // directly set to canceled - if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.CANCELED)) { + // now load the task's invokable code + invokable = loadAndInstantiateInvokable(userCodeClassLoader, nameOfInvokableClass); - notifyObservers(ExecutionState.CANCELED, null); - unregisterTask(); - return; - } + if (isCanceledOrFailed()) { + throw new CancelTaskException(); } - else if (current == ExecutionState.RUNNING) { - // go to canceling and perform the actual task canceling - if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.CANCELING)) { - - notifyObservers(ExecutionState.CANCELING, null); - try { - this.environment.cancelExecution(); - } - catch (Throwable e) { - LOG.error("Error while cancelling the task.", e); - } - return; + // ---------------------------------------------------------------- + // register the task with the network stack + // this operation may fail if the system does not have enough + // memory to run the necessary data exchanges + // the registration must also strictly be undone + // ---------------------------------------------------------------- + + LOG.info("Registering task at network: " + this); + network.registerTask(this); + + // next, kick off the background copying of files for the distributed cache + try { + for (Map.Entry entry : + DistributedCache.readFileInfoFromConfig(jobConfiguration)) + { + LOG.info("Obtaining local cache file for '" + entry.getKey() + '\''); + Future cp = fileCache.createTmpFile(entry.getKey(), entry.getValue(), jobId); + distributedCacheEntries.put(entry.getKey(), cp); } } - else { - throw new RuntimeException("unexpected state for cancelling: " + current); + catch (Exception e) { + throw new Exception("Exception while adding files to distributed cache.", e); } - } - } - /** - * Sets the tasks to be cancelled and reports a failure back to the master. 
- */ - public void failExternally(Throwable cause) { - while (true) { - ExecutionState current = this.executionState; - - // if the task is already canceled (or canceling) or finished or failed, - // then we need not do anything - if (current == ExecutionState.CANCELED || current == ExecutionState.CANCELING || current == ExecutionState.FAILED) { - return; + if (isCanceledOrFailed()) { + throw new CancelTaskException(); } - if (current == ExecutionState.FINISHED) { - // Set state to failed in order to correctly unregister task from network environment - if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) { - notifyObservers(ExecutionState.FAILED, null); + // ---------------------------------------------------------------- + // call the user code initialization methods + // ---------------------------------------------------------------- - return; + TaskInputSplitProvider splitProvider = new TaskInputSplitProvider(jobManager, + jobId, vertexId, executionId, userCodeClassLoader, actorAskTimeout); + + Environment env = new RuntimeEnvironment(jobId, vertexId, executionId, + taskName, taskNameWithSubtask, subtaskIndex, parallelism, + jobConfiguration, taskConfiguration, + userCodeClassLoader, memoryManager, ioManager, broadcastVariableManager, + splitProvider, distributedCacheEntries, + writers, inputGates, jobManager); + + // let the task code create its readers and writers + invokable.setEnvironment(env); + try { + invokable.registerInputOutput(); + } + catch (Exception e) { + throw new Exception("Call to registerInputOutput() of invokable failed", e); + } + + // the very last thing before the actual execution starts running is to inject + // the state into the task. the state is non-empty if this is an execution + // of a task that failed but had backuped state from a checkpoint + if (operatorState != null) { + if (invokable instanceof OperatorStateCarrier) { + ((OperatorStateCarrier) invokable).injectState(operatorState); + } + else { + throw new IllegalStateException("Found operator state for a non-stateful task invokable"); } } - if (current == ExecutionState.DEPLOYING) { - // directly set to canceled - if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) { - this.failureCause = cause; + // ---------------------------------------------------------------- + // actual task core work + // ---------------------------------------------------------------- - notifyObservers(ExecutionState.FAILED, null); - unregisterTask(); - return; + // we must make strictly sure that the invokable is accessible to teh cancel() call + // by the time we switched to running. + this.invokable = invokable; + + // switch to the RUNNING state, if that fails, we have been canceled/failed in the meantime + if (!STATE_UPDATER.compareAndSet(this, ExecutionState.DEPLOYING, ExecutionState.RUNNING)) { + throw new CancelTaskException(); + } + + // notify everyone that we switched to running. especially the TaskManager needs + // to know this! 
+ notifyObservers(ExecutionState.RUNNING, null); + taskManager.tell(new TaskMessages.UpdateTaskExecutionState( + new TaskExecutionState(jobId, executionId, ExecutionState.RUNNING)), ActorRef.noSender()); + + // make sure the user code classloader is accessible thread-locally + executingThread.setContextClassLoader(userCodeClassLoader); + + // run the invokable + invokable.invoke(); + + // make sure, we enter the catch block if the task leaves the invoke() method due + // to the fact that it has been canceled + if (isCanceledOrFailed()) { + throw new CancelTaskException(); + } + + // ---------------------------------------------------------------- + // finalization of a successful execution + // ---------------------------------------------------------------- + + // finish the produced partitions. if this fails, we consider the execution failed. + for (ResultPartition partition : producedPartitions) { + if (partition != null) { + partition.finish(); } } - else if (current == ExecutionState.RUNNING) { - // go to canceling and perform the actual task canceling - if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) { - try { - this.environment.cancelExecution(); + + // try to mark the task as finished + // if that fails, the task was canceled/failed in the meantime + if (STATE_UPDATER.compareAndSet(this, ExecutionState.RUNNING, ExecutionState.FINISHED)) { + notifyObservers(ExecutionState.FINISHED, null); + } + else { + throw new CancelTaskException(); + } + } + catch (Throwable t) { + + // ---------------------------------------------------------------- + // the execution failed. either the invokable code properly failed, or + // an exception was thrown as a side effect of cancelling + // ---------------------------------------------------------------- + + try { + // transition into our final state. we should be either in RUNNING, CANCELING, or FAILED + // loop for multiple retries during concurrent state changes via calls to cancel() or + // to failExternally() + while (true) { + ExecutionState current = this.executionState; + if (current == ExecutionState.RUNNING || current == ExecutionState.DEPLOYING) { + if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) { + // proper failure of the task. 
record the exception as the root cause + failureCause = t; + notifyObservers(ExecutionState.FAILED, t); + + // in case of an exception during execution, we still call "cancel()" on the task + if (invokable != null && this.invokable != null && invokableHasBeenCanceled.compareAndSet(false, true)) { + try { + invokable.cancel(); + } + catch (Throwable t2) { + LOG.error("Error while canceling task " + taskNameWithSubtask, t2); + } + } + break; + } } - catch (Throwable e) { - LOG.error("Error while cancelling the task.", e); + else if (current == ExecutionState.CANCELING) { + if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.CANCELED)) { + notifyObservers(ExecutionState.CANCELED, null); + break; + } } + else if (current == ExecutionState.FAILED) { + // in state failed already, no transition necessary any more + break; + } + // unexpected state, go to failed + else if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.FAILED)) { + LOG.error("Unexpected state in Task during an exception: " + current); + break; + } + // else fall through the loop and + } + } + catch (Throwable tt) { + String message = "FATAL - exception in task exception handler"; + LOG.error(message, tt); + notifyFatalError(message, tt); + } + } + finally { + try { + LOG.info("Freeing task resources for " + taskNameWithSubtask); + + // free the network resources + network.unregisterTask(this); + + if (invokable != null) { + memoryManager.releaseAll(invokable); + } - this.failureCause = cause; + // remove all of the tasks library resources + libraryCache.unregisterTask(jobId, executionId); - notifyObservers(ExecutionState.FAILED, null); - unregisterTask(); + // remove all files in the distributed cache + removeCachedFiles(distributedCacheEntries, fileCache); - return; - } + notifyFinalState(); } - else { - throw new RuntimeException("unexpected state for failing the task: " + current); + catch (Throwable t) { + // an error in the resource cleanup is fatal + String message = "FATAL - exception in task resource cleanup"; + LOG.error(message, t); + notifyFatalError(message, t); } } } - public void cancelingDone() { - while (true) { - ExecutionState current = this.executionState; + private ClassLoader createUserCodeClassloader(LibraryCacheManager libraryCache) throws Exception { + long startDownloadTime = System.currentTimeMillis(); - if (current == ExecutionState.CANCELED || current == ExecutionState.FAILED) { - return; - } - if (!(current == ExecutionState.RUNNING || current == ExecutionState.CANCELING)) { - LOG.error(String.format("Unexpected state transition in Task: %s -> %s", current, ExecutionState.CANCELED)); - } + // triggers the download of all missing jar files from the job manager + libraryCache.registerTask(jobId, executionId, requiredJarFiles); - if (STATE_UPDATER.compareAndSet(this, current, ExecutionState.CANCELED)) { - notifyObservers(ExecutionState.CANCELED, null); - unregisterTask(); - return; - } + LOG.debug("Register task {} at library cache manager took {} milliseconds", + executionId, System.currentTimeMillis() - startDownloadTime); + + ClassLoader userCodeClassLoader = libraryCache.getClassLoader(jobId); + if (userCodeClassLoader == null) { + throw new Exception("No user code classloader available."); } + return userCodeClassLoader; } - /** - * Starts the execution of this task. 
- */ - public boolean startExecution() { - LOG.info("Starting execution of task {}", this.getTaskName()); - if (STATE_UPDATER.compareAndSet(this, ExecutionState.DEPLOYING, ExecutionState.RUNNING)) { - final Thread thread = this.environment.getExecutingThread(); - thread.start(); - return true; + private AbstractInvokable loadAndInstantiateInvokable(ClassLoader classLoader, String className) throws Exception { + Class<? extends AbstractInvokable> invokableClass; + try { + invokableClass = Class.forName(className, true, classLoader) + .asSubclass(AbstractInvokable.class); } - else { - return false; + catch (Throwable t) { + throw new Exception("Could not load the task's invokable class.", t); + } + try { + return invokableClass.newInstance(); + } + catch (Throwable t) { + throw new Exception("Could not instantiate the task's invokable class.", t); } } - /** - * Unregisters the task from the central memory manager. - */ - public void unregisterMemoryManager(MemoryManager memoryManager) { - RuntimeEnvironment env = this.environment; - if (memoryManager != null && env != null) { - memoryManager.releaseAll(env.getInvokable()); + private void removeCachedFiles(Map<String, FutureTask<Path>> entries, FileCache fileCache) { + // cancel and release all distributed cache files + try { + for (Map.Entry<String, FutureTask<Path>> entry : entries.entrySet()) { + String name = entry.getKey(); + try { + fileCache.deleteTmpFile(name, jobId); + } + catch (Exception e) { + // unpleasant, but we continue + LOG.error("Distributed Cache could not remove cached file registered under '" + + name + "'.", e); + } + } + } + catch (Throwable t) { + LOG.error("Error while removing cached local files from distributed cache.", t); } } - protected void unregisterTask() { - taskManager.tell(new UnregisterTask(executionId), ActorRef.noSender()); + private void notifyFinalState() { + taskManager.tell(new TaskInFinalState(executionId), ActorRef.noSender()); } - // ----------------------------------------------------------------------------------------------------------------- - // Task Profiling - // ----------------------------------------------------------------------------------------------------------------- + private void notifyFatalError(String message, Throwable cause) { + taskManager.tell(new FatalError(message, cause), ActorRef.noSender()); + } - /** - * Registers the task manager profiler with the task. - */ - public void registerProfiler(TaskManagerProfiler taskManagerProfiler, Configuration jobConfiguration) { - taskManagerProfiler.registerTask(this, jobConfiguration); + // ---------------------------------------------------------------------------------------------------------------- + // Canceling / Failing the task from the outside + // ---------------------------------------------------------------------------------------------------------------- + + public void cancelExecution() { + LOG.info("Attempting to cancel task " + taskNameWithSubtask); + if (cancelOrFailAndCancelInvokable(ExecutionState.CANCELING)) { + notifyObservers(ExecutionState.CANCELING, null); + } } /** - * Unregisters the task from the task manager profiler. + * Sets the task to be cancelled and reports a failure back to the master.
*/ - public void unregisterProfiler(TaskManagerProfiler taskManagerProfiler) { - if (taskManagerProfiler != null) { - taskManagerProfiler.unregisterTask(this.executionId); + public void failExternally(Throwable cause) { + LOG.info("Attempting to fail task externally " + taskNameWithSubtask); + if (cancelOrFailAndCancelInvokable(ExecutionState.FAILED)) { + failureCause = cause; + notifyObservers(ExecutionState.FAILED, cause); } } - // ------------------------------------------------------------------------ - // Intermediate result partitions - // ------------------------------------------------------------------------ - - public SingleInputGate[] getInputGates() { - return environment != null ? environment.getAllInputGates() : null; - } + private boolean cancelOrFailAndCancelInvokable(ExecutionState targetState) { + while (true) { + ExecutionState current = this.executionState; - public ResultPartitionWriter[] getWriters() { - return environment != null ? environment.getAllWriters() : null; - } + // if the task is already canceled (or canceling) or finished or failed, + // then we need not do anything + if (current.isTerminal() || current == ExecutionState.CANCELING) { + return false; + } - public ResultPartition[] getProducedPartitions() { - return environment != null ? environment.getProducedPartitions() : null; + if (current == ExecutionState.DEPLOYING || current == ExecutionState.CREATED) { + if (STATE_UPDATER.compareAndSet(this, current, targetState)) { + // if we manage this state transition, then the invokable gets never called + // we need not call cancel on it + return true; + } + } + else if (current == ExecutionState.RUNNING) { + if (STATE_UPDATER.compareAndSet(this, ExecutionState.RUNNING, targetState)) { + // we are canceling / failing out of the running state + // we need to cancel the invokable + if (invokable != null && invokableHasBeenCanceled.compareAndSet(false, true)) { + LOG.info("Triggering cancellation of task code {} ({}).", taskNameWithSubtask, executionId); + + // because the canceling may block on user code, we cancel from a separate thread + Runnable canceler = new TaskCanceler(LOG, invokable, executingThread, taskNameWithSubtask); + Thread cancelThread = new Thread(executingThread.getThreadGroup(), canceler, + "Canceler for " + taskNameWithSubtask); + cancelThread.start(); + } + return true; + } + } + else { + throw new IllegalStateException("Unexpected task state: " + current); + } + } } - // -------------------------------------------------------------------------------------------- - // State Listeners - // -------------------------------------------------------------------------------------------- + // ------------------------------------------------------------------------ + // State Listeners + // ------------------------------------------------------------------------ public void registerExecutionListener(ActorRef listener) { executionListenerActors.add(listener); @@ -404,25 +777,146 @@ public class Task { executionListenerActors.remove(listener); } - private void notifyObservers(ExecutionState newState, String message) { - if (LOG.isInfoEnabled()) { - LOG.info(getTaskNameWithSubtasks() + " switched to " + newState + (message == null ? 
"" : " : " + message)); + private void notifyObservers(ExecutionState newState, Throwable error) { + if (error == null) { + LOG.info(taskNameWithSubtask + " switched to " + newState); + } + else { + LOG.info(taskNameWithSubtask + " switched to " + newState + " with exception.", error); } + TaskExecutionState stateUpdate = new TaskExecutionState(jobId, executionId, newState, error); + TaskMessages.UpdateTaskExecutionState actorMessage = new + TaskMessages.UpdateTaskExecutionState(stateUpdate); + for (ActorRef listener : executionListenerActors) { - listener.tell(new ExecutionGraphMessages.ExecutionStateChanged( - jobId, vertexId, taskName, numberOfSubtasks, subtaskIndex, - executionId, newState, System.currentTimeMillis(), message), - ActorRef.noSender()); + listener.tell(actorMessage, ActorRef.noSender()); } } - // -------------------------------------------------------------------------------------------- - // Utilities - // -------------------------------------------------------------------------------------------- + // ------------------------------------------------------------------------ + // Notifications on the invokable + // ------------------------------------------------------------------------ + + public void triggerCheckpointBarrier(final long checkpointID) { + AbstractInvokable invokabe = this.invokable; + + if (executionState == ExecutionState.RUNNING && invokabe != null) { + if (invokabe instanceof BarrierTransceiver) { + final BarrierTransceiver barrierTransceiver = (BarrierTransceiver) invokabe; + final Logger logger = LOG; + + Thread caller = new Thread("Barrier emitter") { + @Override + public void run() { + try { + barrierTransceiver.broadcastBarrierFromSource(checkpointID); + } + catch (Throwable t) { + logger.error("Error while triggering checkpoint barriers", t); + } + } + }; + caller.setDaemon(true); + caller.start(); + } + else { + LOG.error("Task received a checkpoint request, but is not a checkpointing task - " + + taskNameWithSubtask); + } + } + else { + LOG.debug("Ignoring request to trigger a checkpoint barrier"); + } + } + + // ------------------------------------------------------------------------ + // Utilities + // ------------------------------------------------------------------------ @Override public String toString() { return getTaskNameWithSubtasks() + " [" + executionState + ']'; } + + // ------------------------------------------------------------------------ + // Task Names + // ------------------------------------------------------------------------ + + public static String getTaskNameWithSubtask(String name, int subtask, int numSubtasks) { + return name + " (" + (subtask+1) + '/' + numSubtasks + ')'; + } + + public static String getTaskNameWithSubtaskAndID(String name, int subtask, int numSubtasks, ExecutionAttemptID id) { + return name + " (" + (subtask+1) + '/' + numSubtasks + ") (" + id + ')'; + } + + /** + * This runner calls cancel() on the invokable and periodically interrupts the + * thread until it has terminated. + */ + private static class TaskCanceler implements Runnable { + + private final Logger logger; + private final AbstractInvokable invokable; + private final Thread executer; + private final String taskName; + + public TaskCanceler(Logger logger, AbstractInvokable invokable, Thread executer, String taskName) { + this.logger = logger; + this.invokable = invokable; + this.executer = executer; + this.taskName = taskName; + } + + @Override + public void run() { + try { + // the user-defined cancel method may throw errors. 
+ // we need to continue despite that + try { + invokable.cancel(); + } + catch (Throwable t) { + logger.error("Error while canceling the task", t); + } + + // interrupt the running thread initially + executer.interrupt(); + try { + executer.join(10000); + } + catch (InterruptedException e) { + // we can ignore this + } + + // it is possible that the user code does not react immediately. for that + // reason, we repeatedly interrupt the user code until + // it exits + while (executer.isAlive()) { + + // build the stack trace of where the thread is stuck, for the log + StringBuilder bld = new StringBuilder(); + StackTraceElement[] stack = executer.getStackTrace(); + for (StackTraceElement e : stack) { + bld.append(e).append('\n'); + } + + logger.warn("Task '{}' did not react to cancelling signal, but is stuck in method:\n {}", + taskName, bld.toString()); + + executer.interrupt(); + try { + executer.join(5000); + } + catch (InterruptedException e) { + // we can ignore this + } + } + } + catch (Throwable t) { + logger.error("Error in the task canceler", t); + } + } + } } diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/messages/TaskManagerMessages.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/messages/TaskManagerMessages.scala index c81830c36b14ecf60711a8ade666765aa2dd71b7..b12f1b54e364a37734e330168858429792e965a4 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/messages/TaskManagerMessages.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/messages/TaskManagerMessages.scala @@ -24,7 +24,15 @@ import org.apache.flink.runtime.instance.InstanceID * Miscellaneous actor messages exchanged with the TaskManager. */ object TaskManagerMessages { - + + /** + * This message informs the TaskManager about a fatal error that prevents + * it from continuing. + * + * @param description The description of the problem + */ + case class FatalError(description: String, cause: Throwable) + /** * Tells the task manager to send a heartbeat message to the job manager. */ @@ -49,7 +57,7 @@ // -------------------------------------------------------------------------- - // Utility messages used for notifications during TaskManager startup + // Reporting the current TaskManager stack trace // -------------------------------------------------------------------------- /** diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/messages/TaskMessages.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/messages/TaskMessages.scala index c8c57265cd51b916710c1064113b841b84cde4b4..b1a08caff991a41a7fda67bfeea0eb3175036972 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/messages/TaskMessages.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/messages/TaskMessages.scala @@ -67,12 +67,12 @@ object TaskMessages { extends TaskMessage /** - * Unregister the task identified by [[executionID]] from the TaskManager. - * Sent to the TaskManager by futures and callbacks. + * Notifies the TaskManager that the task has reached its final state, + * either FINISHED, CANCELED, or FAILED. * * @param executionID The task's execution attempt ID.
*/ - case class UnregisterTask(executionID: ExecutionAttemptID) + case class TaskInFinalState(executionID: ExecutionAttemptID) extends TaskMessage diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala index 747cc8597aadb1e4ebf2c822e3a1d8a193d0ff39..bdefea6b314c5d0a5369fd38c0081875fcec99a2 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala @@ -21,7 +21,7 @@ package org.apache.flink.runtime.taskmanager import java.io.{File, IOException} import java.net.{InetAddress, InetSocketAddress} import java.util -import java.util.concurrent.{TimeUnit, FutureTask} +import java.util.concurrent.TimeUnit import java.lang.reflect.Method import java.lang.management.{GarbageCollectorMXBean, ManagementFactory, MemoryMXBean} @@ -36,16 +36,13 @@ import com.codahale.metrics.jvm.{MemoryUsageGaugeSet, GarbageCollectorMetricSet} import com.fasterxml.jackson.databind.ObjectMapper import grizzled.slf4j.Logger -import org.apache.flink.api.common.cache.DistributedCache import org.apache.flink.configuration._ -import org.apache.flink.core.fs.Path import org.apache.flink.runtime.{ActorSynchronousLogging, ActorLogMessages} import org.apache.flink.runtime.akka.AkkaUtils import org.apache.flink.runtime.blob.{BlobService, BlobCache} import org.apache.flink.runtime.broadcast.BroadcastVariableManager import org.apache.flink.runtime.deployment.{InputChannelDeploymentDescriptor, TaskDeploymentDescriptor} import org.apache.flink.runtime.execution.librarycache.{BlobLibraryCacheManager, FallbackLibraryCacheManager, LibraryCacheManager} -import org.apache.flink.runtime.execution.{CancelTaskException, ExecutionState, RuntimeEnvironment} import org.apache.flink.runtime.executiongraph.ExecutionAttemptID import org.apache.flink.runtime.filecache.FileCache import org.apache.flink.runtime.instance.{HardwareDescription, InstanceConnectionInfo, InstanceID} @@ -54,7 +51,6 @@ import org.apache.flink.runtime.io.disk.iomanager.{IOManager, IOManagerAsync} import org.apache.flink.runtime.io.network.NetworkEnvironment import org.apache.flink.runtime.io.network.netty.NettyConfig import org.apache.flink.runtime.jobgraph.IntermediateDataSetID -import org.apache.flink.runtime.jobgraph.tasks.{OperatorStateCarrier,BarrierTransceiver} import org.apache.flink.runtime.jobmanager.JobManager import org.apache.flink.runtime.memorymanager.{MemoryManager, DefaultMemoryManager} import org.apache.flink.runtime.messages.CheckpointingMessages.{CheckpointingMessage, BarrierReq} @@ -67,9 +63,6 @@ import org.apache.flink.runtime.process.ProcessReaper import org.apache.flink.runtime.security.SecurityUtils import org.apache.flink.runtime.security.SecurityUtils.FlinkSecuredRunner import org.apache.flink.runtime.util.{MathUtils, EnvironmentInformation} -import org.apache.flink.util.ExceptionUtils - -import org.slf4j.LoggerFactory import scala.concurrent._ import scala.concurrent.duration._ @@ -136,13 +129,13 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { protected val resources = HardwareDescription.extractFromSystem(memoryManager.getMemorySize) /** Registry of all tasks currently executed by this TaskManager */ - protected val runningTasks = new util.concurrent.ConcurrentHashMap[ExecutionAttemptID, Task]() + protected val runningTasks = new java.util.HashMap[ExecutionAttemptID, Task]() /** 
Handler for shared broadcast variables (shared between multiple Tasks) */ protected val bcVarManager = new BroadcastVariableManager() /** Handler for distributed files cached by this TaskManager */ - protected val fileCache = new FileCache() + protected val fileCache = new FileCache(config.configuration) /** Registry of metrics periodically transmitted to the JobManager */ private val metricRegistry = TaskManager.createMetricsRegistry() @@ -282,6 +275,9 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { case Disconnect(msg) => handleJobManagerDisconnect(sender(), "JobManager requested disconnect: " + msg) + + case FatalError(message, cause) => + killTaskManagerFatal(message, cause) } /** @@ -344,12 +340,13 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { // state transition case updateMsg @ UpdateTaskExecutionState(taskExecutionState: TaskExecutionState) => + + // we receive these from our tasks and forward them to the JobManager currentJobManager foreach { jobManager => { val futureResponse = (jobManager ? updateMsg)(askTimeout) - + val executionID = taskExecutionState.getID - val executionState = taskExecutionState.getExecutionState futureResponse.mapTo[Boolean].onComplete { // IMPORTANT: In the future callback, we cannot directly modify state @@ -359,21 +356,16 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { self ! FailTask(executionID, new Exception("Task has been cancelled on the JobManager.")) } - - if (!result || executionState.isTerminal) { - self ! UnregisterTask(executionID) - } + case Failure(t) => self ! FailTask(executionID, new Exception( "Failed to send ExecutionStateChange notification to JobManager")) - - self ! UnregisterTask(executionID) }(context.dispatcher) } } // removes the task from the TaskManager and frees all its resources - case UnregisterTask(executionID) => + case TaskInFinalState(executionID) => unregisterTaskAndNotifyFinalState(executionID) // starts a new task on the TaskManager @@ -383,35 +375,22 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { // marks a task as failed for an external reason // external reasons are reasons other than the task code itself throwing an exception case FailTask(executionID, cause) => - Option(runningTasks.get(executionID)) match { - case Some(task) => - - // execute failing operation concurrently - implicit val executor = context.dispatcher - Future { - task.failExternally(cause) - }.onFailure{ - case t: Throwable => log.error(s"Could not fail task ${task} externally.", t) - } - case None => + val task = runningTasks.get(executionID) + if (task != null) { + task.failExternally(cause) + } else { + log.debug(s"Cannot find task to fail for execution ${executionID})") } // cancels a task case CancelTask(executionID) => - Option(runningTasks.get(executionID)) match { - case Some(task) => - // execute cancel operation concurrently - implicit val executor = context.dispatcher - Future { - task.cancelExecution() - }.onFailure{ - case t: Throwable => log.error("Could not cancel task " + task, t) - } - - sender ! new TaskOperationResult(executionID, true) - - case None => - sender ! new TaskOperationResult(executionID, false, + val task = runningTasks.get(executionID) + if (task != null) { + task.cancelExecution() + sender ! new TaskOperationResult(executionID, true) + } else { + log.debug(s"Cannot find task to cancel for execution ${executionID})") + sender ! 
new TaskOperationResult(executionID, false, "No task with that execution ID was found.") } } @@ -430,24 +409,11 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { log.debug(s"[FT-TaskManager] Barrier $checkpointID request received " + s"for attempt $attemptID.") - Option(runningTasks.get(attemptID)) match { - case Some(i) => - if (i.getExecutionState == ExecutionState.RUNNING) { - i.getEnvironment.getInvokable match { - case barrierTransceiver: BarrierTransceiver => - new Thread(new Runnable { - override def run(): Unit = - barrierTransceiver.broadcastBarrierFromSource(checkpointID) - }).start() - - case _ => log.error("Taskmanager received a checkpoint request for " + - s"non-checkpointing task $attemptID.") - } - } - - case None => - // may always happen in case of canceled/finished tasks - log.debug(s"Taskmanager received a checkpoint request for unknown task $attemptID.") + val task = runningTasks.get(attemptID) + if (task != null) { + task.triggerCheckpointBarrier(checkpointID) + } else { + log.debug(s"Taskmanager received a checkpoint request for unknown task $attemptID.") } // unknown checkpoint message @@ -770,8 +736,7 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { } } } - - + // -------------------------------------------------------------------------- // Task Operations // -------------------------------------------------------------------------- @@ -784,130 +749,46 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { * @param tdd TaskDeploymentDescriptor describing the task to be executed on this [[TaskManager]] */ private def submitTask(tdd: TaskDeploymentDescriptor): Unit = { - val slot = tdd.getTargetSlotNumber - - if (!isConnected) { - sender ! Failure( - new IllegalStateException("TaskManager is not associated with a JobManager.") - ) - } else if (slot < 0 || slot >= numberOfSlots) { - sender ! Failure(new Exception(s"Target slot $slot does not exist on TaskManager.")) - } else { - sender ! Acknowledge - - Future { - initializeTask(tdd) - }(context.dispatcher) - } - } - - /** Sets up a [[org.apache.flink.runtime.execution.RuntimeEnvironment]] for the task and starts - * its execution in a separate thread. 
- * - * @param tdd TaskDeploymentDescriptor describing the task to be executed on this [[TaskManager]] - */ - private def initializeTask(tdd: TaskDeploymentDescriptor): Unit ={ - val jobID = tdd.getJobID - val vertexID = tdd.getVertexID - val executionID = tdd.getExecutionId - val taskIndex = tdd.getIndexInSubtaskGroup - val numSubtasks = tdd.getNumberOfSubtasks - var startRegisteringTask = 0L - var task: Task = null - try { - val userCodeClassLoader = libraryCacheManager match { - case Some(manager) => - if (log.isDebugEnabled) { - startRegisteringTask = System.currentTimeMillis() - } - - // triggers the download of all missing jar files from the job manager - manager.registerTask(jobID, executionID, tdd.getRequiredJarFiles) - - if (log.isDebugEnabled) { - log.debug(s"Register task $executionID at library cache manager " + - s"took ${(System.currentTimeMillis() - startRegisteringTask) / 1000.0}s") - } - - manager.getClassLoader(jobID) - case None => throw new IllegalStateException("There is no valid library cache manager.") - } - - if (userCodeClassLoader == null) { - throw new RuntimeException("No user code Classloader available.") - } - - task = new Task(jobID, vertexID, taskIndex, numSubtasks, executionID, - tdd.getTaskName, self) - - Option(runningTasks.put(executionID, task)) match { - case Some(_) => throw new RuntimeException( - s"TaskManager contains already a task with executionID $executionID.") + // grab some handles and sanity check on the fly + val jobManagerActor = currentJobManager match { + case Some(jm) => jm case None => + throw new IllegalStateException("TaskManager is not associated with a JobManager.") } - - val env = currentJobManager match { - case Some(jobManager) => - val splitProvider = new TaskInputSplitProvider(jobManager, jobID, vertexID, - executionID, userCodeClassLoader, askTimeout) - - new RuntimeEnvironment(jobManager, task, tdd, userCodeClassLoader, - memoryManager, ioManager, splitProvider, bcVarManager, network) - - case None => throw new IllegalStateException( - "TaskManager has not yet been registered at a JobManager.") - } - - task.setEnvironment(env) - - //inject operator state - if (tdd.getOperatorStates != null) { - task.getEnvironment.getInvokable match { - case opStateCarrier: OperatorStateCarrier => - opStateCarrier.injectState(tdd.getOperatorStates) - } + val libCache = libraryCacheManager match { + case Some(manager) => manager + case None => throw new IllegalStateException("There is no valid library cache manager.") } - // register the task with the network stack and profiles - log.info(s"Register task $task.") - network.registerTask(task) - - val cpTasks = new util.HashMap[String, FutureTask[Path]]() - - for (entry <- DistributedCache.readFileInfoFromConfig(tdd.getJobConfiguration).asScala) { - val cp = fileCache.createTmpFile(entry.getKey, entry.getValue, jobID) - cpTasks.put(entry.getKey, cp) + val slot = tdd.getTargetSlotNumber + if (slot < 0 || slot >= numberOfSlots) { + throw new IllegalArgumentException(s"Target slot $slot does not exist on TaskManager.") } - env.addCopyTasksForCacheFile(cpTasks) - if (!task.startExecution()) { - throw new RuntimeException("Cannot start task. Task was canceled or failed.") + // create the task. 
this does not grab any TaskManager resources or download + // any libraries - the operation does not block + val execId = tdd.getExecutionId + val task = new Task(tdd, memoryManager, ioManager, network, bcVarManager, + self, jobManagerActor, config.timeout, libCache, fileCache) + + // add the task to the map + val prevTask = runningTasks.put(execId, task) + if (prevTask != null) { + // already have a task for that ID, put it back and report an error + runningTasks.put(execId, prevTask) + throw new IllegalStateException("TaskManager already contains a task for id " + execId) } - - self ! UpdateTaskExecutionState( - new TaskExecutionState(jobID, executionID, ExecutionState.RUNNING) - ) - } catch { - case t: Throwable => - if (!t.isInstanceOf[CancelTaskException]) { - log.error("Could not instantiate task with execution ID " + executionID, t) - } - - try { - if (task != null) { - task.failExternally(t) - removeAllTaskResources(task) - } - - libraryCacheManager foreach { _.unregisterTask(jobID, executionID) } - } catch { - case t: Throwable => log.error("Error during cleanup of task deployment.", t) - } - - self ! UpdateTaskExecutionState( - new TaskExecutionState(jobID, executionID, ExecutionState.FAILED, t) - ) + + // all good, we kick off the task, which performs its own initialization + task.startTaskThread() + + sender ! Acknowledge + } + catch { + case t: Throwable => + log.error("SubmitTask failed", t) + sender ! Failure(t) } } @@ -927,19 +808,20 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { val errors: Seq[String] = partitionInfos.flatMap { info => val (resultID, partitionInfo) = info - val reader = task.getEnvironment.getInputGateById(resultID) + val reader = task.getInputGateById(resultID) if (reader != null) { Future { try { reader.updateInputChannel(partitionInfo) - } catch { + } + catch { case t: Throwable => log.error(s"Could not update input data location for task " + s"${task.getTaskName}. Trying to fail task.", t) try { - task.markFailed(t) + task.failExternally(t) } catch { case t: Throwable => @@ -977,20 +859,20 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { private def cancelAndClearEverything(cause: Throwable) { if (runningTasks.size > 0) { log.info("Cancelling all computations and discarding all cached data.") - + for (t <- runningTasks.values().asScala) { t.failExternally(cause) - unregisterTaskAndNotifyFinalState(t.getExecutionId) } + runningTasks.clear() } } private def unregisterTaskAndNotifyFinalState(executionID: ExecutionAttemptID): Unit = { - Option(runningTasks.remove(executionID)) match { - case Some(task) => - - // mark the task as failed if it is not yet in a final state + val task = runningTasks.remove(executionID) + if (task != null) { + + // the task must be in a terminal state; if it is not, fail it before unregistering if (!task.getExecutionState.isTerminal) { try { task.failExternally(new Exception("Task is being removed from TaskManager")) @@ -999,66 +881,15 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging { } } - log.info(s"Unregister task with execution ID $executionID.") - removeAllTaskResources(task) - libraryCacheManager foreach { _.unregisterTask(task.getJobID, executionID) } - - log.info(s"Updating FINAL execution state of ${task.getTaskName} " + - s"(${task.getExecutionId}) to ${task.getExecutionState}.") + log.info(s"Unregistering task and sending final execution state " + + s"${task.getExecutionState} to JobManager for task ${task.getTaskName} " + + s"(${task.getExecutionId})") self !
UpdateTaskExecutionState(new TaskExecutionState( task.getJobID, task.getExecutionId, task.getExecutionState, task.getFailureCause)) - - case None => - log.debug(s"Cannot find task with ID $executionID to unregister.") } - } - - /** - * This method cleans up the resources of a task in the distributed cache, - * network stack and the memory manager. - * - * If the cleanup in the network stack or memory manager fails, this is considered - * a fatal problem (critical resource leak) and causes the TaskManager to quit. - * A TaskManager JVM restart is the best safe way to fix that error. - * - * @param task The Task whose resources should be cleared. - */ - private def removeAllTaskResources(task: Task): Unit = { - - // release the critical things first, and fail fatally if it does not work - - // this releases all task resources, like buffer pools and intermediate result - // partitions being built. If this fails, the TaskManager is in serious trouble, - // as this is a massive resource leak. We kill the TaskManager in that case, - // to recover through a clean JVM start - try { - network.unregisterTask(task) - } catch { - case t: Throwable => - killTaskManagerFatal("Failed to unregister task resources from network stack", t) - } - - // safety net to release all the task's memory - try { - task.unregisterMemoryManager(memoryManager) - } catch { - case t: Throwable => - killTaskManagerFatal("Failed to unregister task memory from memory manager", t) - } - - // release temp files from the distributed cache - if (task.getEnvironment != null) { - try { - for (entry <- DistributedCache.readFileInfoFromConfig( - task.getEnvironment.getJobConfiguration).asScala) { - fileCache.deleteTmpFile(entry.getKey, entry.getValue, task.getJobID) - } - } catch { - // this is pretty unpleasant, but not a reason to give up immediately - case e: Exception => log.error( - "Error cleaning up local temp files from the distributed cache.", e) - } + else { + log.error(s"Cannot find task with ID $executionID to unregister.") } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java index d13db057c6774e80525048c7d1ef5054ffdeae97..fdf41f04e802fd2732d24867283af113a5dd597a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java @@ -19,8 +19,8 @@ package org.apache.flink.runtime.io.network.partition.consumer; import com.google.common.collect.Lists; + import org.apache.flink.api.common.JobID; -import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.buffer.BufferPool; @@ -36,6 +36,7 @@ import org.apache.flink.runtime.io.network.util.TestPartitionProducer; import org.apache.flink.runtime.io.network.util.TestProducerSource; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; + import org.junit.Test; import java.util.Collections; @@ -93,7 +94,7 @@ public class LocalInputChannelTest { partitionIds[i] = new ResultPartitionID(); final ResultPartition partition = new ResultPartition( - mock(Environment.class), + "Test Name", jobId, 
partitionIds[i], ResultPartitionType.PIPELINED, @@ -222,7 +223,7 @@ public class LocalInputChannelTest { checkArgument(numberOfExpectedBuffersPerChannel >= 1); this.inputGate = new SingleInputGate( - mock(Environment.class), + "Test Name", new IntermediateDataSetID(), subpartitionIndex, numberOfInputChannels); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index be66aeb93199ecd83be867f426115bba5f812201..9a7ffe53f0b8906eaa4ad10a1853d3854988381b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -22,8 +22,6 @@ import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionLocation; import org.apache.flink.runtime.event.task.TaskEvent; -import org.apache.flink.runtime.execution.Environment; -import org.apache.flink.runtime.execution.RuntimeEnvironment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.TaskEventDispatcher; @@ -61,7 +59,7 @@ public class SingleInputGateTest { public void testBasicGetNextLogic() throws Exception { // Setup final SingleInputGate inputGate = new SingleInputGate( - mock(Environment.class), new IntermediateDataSetID(), 0, 2); + "Test Task Name", new IntermediateDataSetID(), 0, 2); final TestInputChannel[] inputChannels = new TestInputChannel[]{ new TestInputChannel(inputGate, 0), @@ -107,7 +105,7 @@ public class SingleInputGateTest { // Setup reader with one local and one unknown input channel final IntermediateDataSetID resultId = new IntermediateDataSetID(); - final SingleInputGate inputGate = new SingleInputGate(mock(Environment.class), resultId, 0, 2); + final SingleInputGate inputGate = new SingleInputGate("Test Task Name", resultId, 0, 2); final BufferPool bufferPool = mock(BufferPool.class); when(bufferPool.getNumberOfRequiredMemorySegments()).thenReturn(2); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java index efc02de2d8b228c5b9a6fb68577693e42857cea7..2dafaa26919eae76fb791469369eee9f974c7a6a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; @@ -30,7 +29,6 @@ import java.util.List; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkElementIndex; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; /** @@ -50,7 +48,7 @@ public class TestSingleInputGate { 
checkArgument(numberOfInputChannels >= 1); this.inputGate = spy(new SingleInputGate( - mock(Environment.class), new IntermediateDataSetID(), 0, numberOfInputChannels)); + "Test Task Name", new IntermediateDataSetID(), 0, numberOfInputChannels)); this.inputChannels = new TestInputChannel[numberOfInputChannels]; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java index 90132ffdab2642c116b4c8b54041433018b910ba..050f43a92af6ec7d05c98e37cb1b74dae19886cc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java @@ -18,14 +18,13 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; + import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; public class UnionInputGateTest { @@ -39,9 +38,9 @@ public class UnionInputGateTest { @Test(timeout = 120 * 1000) public void testBasicGetNextLogic() throws Exception { // Setup - final Environment env = mock(Environment.class); - final SingleInputGate ig1 = new SingleInputGate(env, new IntermediateDataSetID(), 0, 3); - final SingleInputGate ig2 = new SingleInputGate(env, new IntermediateDataSetID(), 0, 5); + final String testTaskName = "Test Task"; + final SingleInputGate ig1 = new SingleInputGate(testTaskName, new IntermediateDataSetID(), 0, 3); + final SingleInputGate ig2 = new SingleInputGate(testTaskName, new IntermediateDataSetID(), 0, 5); final UnionInputGate union = new UnionInputGate(new SingleInputGate[]{ig1, ig2}); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index 16cd66e6eea8c69d6aaa1283ab29867f3bcf2302..735f67ef3657608aaf7bf3c6ea6946f5fe67f663 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.operators.testutils; +import akka.actor.ActorRef; import org.apache.flink.api.common.accumulators.Accumulator; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.Path; @@ -261,4 +262,9 @@ public class MockEnvironment implements Environment { public void reportAccumulators(Map> accumulators) { // discard, this is only for testing } + + @Override + public ActorRef getJobManager() { + return ActorRef.noSender(); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/ForwardingActor.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/ForwardingActor.java new file mode 100644 index 0000000000000000000000000000000000000000..70e6f22d9dc9a111538b8a75e256d310e84264bb --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/ForwardingActor.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.taskmanager; + +import akka.actor.UntypedActor; + +import java.util.concurrent.BlockingQueue; + +/** + * Actor for testing that simply puts all its messages into a + * blocking queue. + */ +class ForwardingActor extends UntypedActor { + + private final BlockingQueue<Object> queue; + + public ForwardingActor(BlockingQueue<Object> queue) { + this.queue = queue; + } + + @Override + public void onReceive(Object message) { + queue.add(message); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskExecutionStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskExecutionStateTest.java index 9d1059112f62de6baf0571ee5db9bb1898d0a822..f4c7a5717c93a9da9b4885f9da128b774aecc4bb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskExecutionStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskExecutionStateTest.java @@ -21,6 +21,8 @@ package org.apache.flink.runtime.taskmanager; import static org.junit.Assert.*; import java.io.IOException; +import java.io.PrintStream; +import java.io.PrintWriter; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; @@ -84,4 +86,35 @@ public class TaskExecutionStateTest { fail(e.getMessage()); } } + + @Test + public void handleNonSerializableException() { + try { + @SuppressWarnings({"ThrowableInstanceNeverThrown", "serial"}) + Exception hostile = new Exception() { + // should be non-serializable, because it contains the outer class reference + + @Override + public String getMessage() { + throw new RuntimeException("Cannot get Message"); + } + + @Override + public void printStackTrace(PrintStream s) { + throw new RuntimeException("Cannot print"); + } + + @Override + public void printStackTrace(PrintWriter s) { + throw new RuntimeException("Cannot print"); + } + }; + + new TaskExecutionState(new JobID(), new ExecutionAttemptID(), ExecutionState.FAILED, hostile); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java index 056ed3707003c200748d186892f7a8dbe5ec703a..d84cb37e5c05b8c43909adad09d0ab3d5ad54871 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java @@ -29,6 +29,7 @@ import akka.testkit.JavaTestKit; import akka.util.Timeout; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.akka.AkkaUtils; import
org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; @@ -56,11 +57,14 @@ import org.apache.flink.runtime.messages.TaskMessages.SubmitTask; import org.apache.flink.runtime.messages.TaskMessages.TaskOperationResult; import org.apache.flink.runtime.testingUtils.TestingTaskManager; import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages; -import org.apache.flink.runtime.testingUtils.TestingUtils; + import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import scala.Option; import scala.concurrent.Await; import scala.concurrent.Future; @@ -75,42 +79,61 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; @SuppressWarnings("serial") public class TaskManagerTest { - private static ActorSystem system; - - private static Timeout timeout = new Timeout(1, TimeUnit.MINUTES); + private static final Logger LOG = LoggerFactory.getLogger(TaskManagerTest.class); + + private static final Timeout timeout = new Timeout(1, TimeUnit.MINUTES); - private final FiniteDuration d = new FiniteDuration(20, TimeUnit.SECONDS); + private static final FiniteDuration d = new FiniteDuration(20, TimeUnit.SECONDS); + private static ActorSystem system; + @BeforeClass public static void setup() { - system = ActorSystem.create("TestActorSystem", TestingUtils.testConfig()); + system = AkkaUtils.createLocalActorSystem(new Configuration()); } @AfterClass public static void teardown() { JavaTestKit.shutdownActorSystem(system); - system = null; } @Test - public void testSetupTaskManager() { + public void testSubmitAndExecuteTask() { + + LOG.info( "--------------------------------------------------------------------\n" + + " Starting testSubmitAndExecuteTask() \n" + + "--------------------------------------------------------------------"); + + new JavaTestKit(system){{ - ActorRef jobManager = null; + ActorRef taskManager = null; + try { - jobManager = system.actorOf(Props.create(SimpleJobManager.class)); - - taskManager = createTaskManager(jobManager); + taskManager = createTaskManager(getTestActor(), false); + final ActorRef tmClosure = taskManager; + + // handle the registration + new Within(d) { + @Override + protected void run() { + expectMsgClass(RegistrationMessages.RegisterTaskManager.class); + + final InstanceID iid = new InstanceID(); + assertEquals(tmClosure, getLastSender()); + tmClosure.tell(new RegistrationMessages.AcknowledgeRegistration( + getTestActor(), iid, 12345), getTestActor()); + } + }; - JobID jid = new JobID(); - JobVertexID vid = new JobVertexID(); + final JobID jid = new JobID(); + final JobVertexID vid = new JobVertexID(); final ExecutionAttemptID eid = new ExecutionAttemptID(); final TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(jid, vid, eid, "TestTask", 2, 7, @@ -119,13 +142,54 @@ public class TaskManagerTest { Collections.emptyList(), new ArrayList(), 0); - final ActorRef tmClosure = taskManager; + new Within(d) { @Override protected void run() { tmClosure.tell(new SubmitTask(tdd), getRef()); - expectMsgEquals(Messages.getAcknowledge()); + + // TaskManager should acknowledge the submission + // heartbeats may be interleaved + long deadline = System.currentTimeMillis() + 
10000; + do { + Object message = receiveOne(d); + if (message == Messages.getAcknowledge()) { + break; + } + } while (System.currentTimeMillis() < deadline); + + // task should have switched to running + Object toRunning = new TaskMessages.UpdateTaskExecutionState( + new TaskExecutionState(jid, eid, ExecutionState.RUNNING)); + + // task should have switched to finished + Object toFinished = new TaskMessages.UpdateTaskExecutionState( + new TaskExecutionState(jid, eid, ExecutionState.FINISHED)); + + deadline = System.currentTimeMillis() + 10000; + do { + Object message = receiveOne(d); + if (message.equals(toRunning)) { + break; + } + else if (!(message instanceof TaskManagerMessages.Heartbeat)) { + fail("Unexpected message: " + message); + } + } while (System.currentTimeMillis() < deadline); + + deadline = System.currentTimeMillis() + 10000; + do { + Object message = receiveOne(d); + if (message.equals(toFinished)) { + break; + } + else if (!(message instanceof TaskManagerMessages.Heartbeat)) { + fail("Unexpected message: " + message); + } + } while (System.currentTimeMillis() < deadline); + + } }; } @@ -138,22 +202,24 @@ public class TaskManagerTest { if (taskManager != null) { taskManager.tell(Kill.getInstance(), ActorRef.noSender()); } - if (jobManager != null) { - jobManager.tell(Kill.getInstance(), ActorRef.noSender()); - } } }}; } @Test public void testJobSubmissionAndCanceling() { + + LOG.info( "--------------------------------------------------------------------\n" + + " Starting testJobSubmissionAndCanceling() \n" + + "--------------------------------------------------------------------"); + new JavaTestKit(system){{ ActorRef jobManager = null; ActorRef taskManager = null; try { jobManager = system.actorOf(Props.create(SimpleJobManager.class)); - taskManager = createTaskManager(jobManager); + taskManager = createTaskManager(jobManager, true); final JobID jid1 = new JobID(); final JobID jid2 = new JobID(); @@ -274,6 +340,11 @@ public class TaskManagerTest { @Test public void testGateChannelEdgeMismatch() { + + LOG.info( "--------------------------------------------------------------------\n" + + " Starting testGateChannelEdgeMismatch() \n" + + "--------------------------------------------------------------------"); + new JavaTestKit(system){{ ActorRef jobManager = null; @@ -281,7 +352,7 @@ public class TaskManagerTest { try { jobManager = system.actorOf(Props.create(SimpleJobManager.class)); - taskManager = createTaskManager(jobManager); + taskManager = createTaskManager(jobManager, true); final ActorRef tm = taskManager; final JobID jid = new JobID(); @@ -353,6 +424,11 @@ public class TaskManagerTest { @Test public void testRunJobWithForwardChannel() { + + LOG.info( "--------------------------------------------------------------------\n" + + " Starting testRunJobWithForwardChannel() \n" + + "--------------------------------------------------------------------"); + new JavaTestKit(system){{ ActorRef jobManager = null; @@ -368,7 +444,7 @@ public class TaskManagerTest { jobManager = system.actorOf(Props.create(new SimpleLookupJobManagerCreator())); - taskManager = createTaskManager(jobManager); + taskManager = createTaskManager(jobManager, true); final ActorRef tm = taskManager; IntermediateResultPartitionID partitionId = new IntermediateResultPartitionID(); @@ -470,6 +546,10 @@ public class TaskManagerTest { @Test public void testCancellingDependentAndStateUpdateFails() { + LOG.info( "--------------------------------------------------------------------\n" + + " Starting 
testCancellingDependentAndStateUpdateFails() \n" + + "--------------------------------------------------------------------"); + // this tests creates two tasks. the sender sends data, and fails to send the // state update back to the job manager // the second one blocks to be canceled @@ -491,7 +571,7 @@ public class TaskManagerTest { new SimpleLookupFailingUpdateJobManagerCreator(eid2) ) ); - taskManager = createTaskManager(jobManager); + taskManager = createTaskManager(jobManager, true); final ActorRef tm = taskManager; IntermediateResultPartitionID partitionId = new IntermediateResultPartitionID(); @@ -676,7 +756,7 @@ public class TaskManagerTest { } } - public static ActorRef createTaskManager(ActorRef jobManager) { + public static ActorRef createTaskManager(ActorRef jobManager, boolean waitForRegistration) { ActorRef taskManager = null; try { Configuration cfg = new Configuration(); @@ -695,16 +775,18 @@ public class TaskManagerTest { fail("Could not create test TaskManager: " + e.getMessage()); } - Future response = Patterns.ask(taskManager, - TaskManagerMessages.getNotifyWhenRegisteredAtJobManagerMessage(), timeout); - - try { - FiniteDuration d = new FiniteDuration(100, TimeUnit.SECONDS); - Await.ready(response, d); - } - catch (Exception e) { - e.printStackTrace(); - fail("Exception while waiting for the task manager registration: " + e.getMessage()); + if (waitForRegistration) { + Future response = Patterns.ask(taskManager, + TaskManagerMessages.getNotifyWhenRegisteredAtJobManagerMessage(), timeout); + + try { + FiniteDuration d = new FiniteDuration(100, TimeUnit.SECONDS); + Await.ready(response, d); + } + catch (Exception e) { + e.printStackTrace(); + fail("Exception while waiting for the task manager registration: " + e.getMessage()); + } } return taskManager; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java index 3a8fcd893c84187f97a8f2694fcb3a4349ee05ed..4492372191177a0736a9bca4c8d0395e18cf488e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java @@ -19,72 +19,277 @@ package org.apache.flink.runtime.taskmanager; import akka.actor.ActorRef; +import akka.actor.ActorSystem; +import akka.actor.Kill; +import akka.actor.Props; + import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; import org.apache.flink.runtime.execution.ExecutionState; -import org.apache.flink.runtime.execution.RuntimeEnvironment; +import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManager; -import org.apache.flink.runtime.io.network.MockNetworkEnvironment; import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; +import 
org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; -import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memorymanager.MemoryManager; -import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.runtime.messages.TaskMessages; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; +import scala.concurrent.duration.FiniteDuration; -import java.util.ArrayList; import java.util.Collections; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Mockito.doNothing; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +/** + * Tests for the Task, which make sure that correct state transitions happen, + * and failures are correctly handled. + * + * All tests here have a set of mock actors for TaskManager, JobManager, and + * execution listener, which simply put the messages in a queue to be picked + * up by the test and validated. + */ public class TaskTest { + + private static ActorSystem actorSystem; + + private static OneShotLatch awaitLatch; + private static OneShotLatch triggerLatch; + + private ActorRef taskManagerMock; + private ActorRef jobManagerMock; + private ActorRef listenerActor; + + private BlockingQueue taskManagerMessages; + private BlockingQueue jobManagerMessages; + private BlockingQueue listenerMessages; + + // ------------------------------------------------------------------------ + // Init & Shutdown + // ------------------------------------------------------------------------ + @BeforeClass + public static void startActorSystem() { + actorSystem = AkkaUtils.createLocalActorSystem(new Configuration()); + } + + @AfterClass + public static void shutdown() { + actorSystem.shutdown(); + actorSystem.awaitTermination(); + } + + @Before + public void createQueuesAndActors() { + taskManagerMessages = new LinkedBlockingQueue(); + jobManagerMessages = new LinkedBlockingQueue(); + listenerMessages = new LinkedBlockingQueue(); + taskManagerMock = actorSystem.actorOf(Props.create(ForwardingActor.class, taskManagerMessages)); + jobManagerMock = actorSystem.actorOf(Props.create(ForwardingActor.class, jobManagerMessages)); + listenerActor = actorSystem.actorOf(Props.create(ForwardingActor.class, listenerMessages)); + + awaitLatch = new OneShotLatch(); + triggerLatch = new OneShotLatch(); + } + + @After + public void clearActorsAndMessages() { + jobManagerMessages = null; + taskManagerMessages = null; + listenerMessages = null; + taskManagerMock.tell(Kill.getInstance(), ActorRef.noSender()); + jobManagerMock.tell(Kill.getInstance(), ActorRef.noSender()); + listenerActor.tell(Kill.getInstance(), ActorRef.noSender()); + } + + // ------------------------------------------------------------------------ + // Tests + // 
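Props.create(ForwardingActor.class, queue) above refers to a forwarding actor that is not shown in this hunk. As a rough sketch, assuming it does nothing more than append every received message to the queue handed to its constructor, it could look like the following (ForwardingActorSketch is a placeholder name):

import akka.actor.UntypedActor;

import java.util.concurrent.BlockingQueue;

// Sketch only: the actual ForwardingActor used by these tests is defined elsewhere in the patch.
public class ForwardingActorSketch extends UntypedActor {

	private final BlockingQueue<Object> messages;

	public ForwardingActorSketch(BlockingQueue<Object> messages) {
		this.messages = messages;
	}

	@Override
	public void onReceive(Object message) {
		messages.add(message);   // hand every message to the test thread via the queue
	}
}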
------------------------------------------------------------------------ + @Test - public void testTaskStates() { + public void testRegularExecution() { try { - final JobID jid = new JobID(); - final JobVertexID vid = new JobVertexID(); - final ExecutionAttemptID eid = new ExecutionAttemptID(); + Task task = createTask(TestInvokableCorrect.class); - final RuntimeEnvironment env = mock(RuntimeEnvironment.class); + // task should be new and perfect + assertEquals(ExecutionState.CREATED, task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); - Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender())); - doNothing().when(task).unregisterTask(); - task.setEnvironment(env); + task.registerExecutionListener(listenerActor); - assertEquals(ExecutionState.DEPLOYING, task.getExecutionState()); + // go into the run method. we should switch to DEPLOYING, RUNNING, then + // FINISHED, and all should be good + task.run(); - // cancel - task.cancelExecution(); - assertEquals(ExecutionState.CANCELED, task.getExecutionState()); + // verify final state + assertEquals(ExecutionState.FINISHED, task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); - // cannot go into running or finished state + // verify listener messages + validateListenerMessage(ExecutionState.RUNNING, task, false); + validateListenerMessage(ExecutionState.FINISHED, task, false); - assertFalse(task.startExecution()); - assertEquals(ExecutionState.CANCELED, task.getExecutionState()); + // make sure that the TaskManager received an message to unregister the task + validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); + validateUnregisterTask(task.getExecutionId()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCancelRightAway() { + try { + Task task = createTask(TestInvokableCorrect.class); + task.cancelExecution(); + + assertEquals(ExecutionState.CANCELING, task.getExecutionState()); - assertFalse(task.markAsFinished()); + task.run(); + + // verify final state assertEquals(ExecutionState.CANCELED, task.getExecutionState()); + validateUnregisterTask(task.getExecutionId()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testFailExternallyRightAway() { + try { + Task task = createTask(TestInvokableCorrect.class); + task.failExternally(new Exception("fail externally")); + + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + + task.run(); + + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + validateUnregisterTask(task.getExecutionId()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testLibraryCacheRegistrationFailed() { + try { + Task task = createTask(TestInvokableCorrect.class, mock(LibraryCacheManager.class)); + + // task should be new and perfect + assertEquals(ExecutionState.CREATED, task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); + + task.registerExecutionListener(listenerActor); + + // should fail + task.run(); + + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNotNull(task.getFailureCause()); + assertTrue(task.getFailureCause().getMessage().contains("classloader")); + + // verify listener messages + 
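The awaitLatch and triggerLatch fields declared above are instances of a OneShotLatch utility that is not shown in this hunk. The blocking invokables further down only rely on its contract: trigger() may be called once from any thread, and await() blocks until that has happened. A minimal sketch of that contract (OneShotLatchSketch is a placeholder name, not the real utility):

// Sketch of the assumed OneShotLatch contract; the real utility lives elsewhere in the test code.
public class OneShotLatchSketch {

	private final Object lock = new Object();
	private boolean triggered;

	public void trigger() {
		synchronized (lock) {
			triggered = true;
			lock.notifyAll();     // wake up everyone blocked in await()
		}
	}

	public void await() throws InterruptedException {
		synchronized (lock) {
			while (!triggered) {
				lock.wait();
			}
		}
	}
}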
validateListenerMessage(ExecutionState.FAILED, task, true); + + // make sure that the TaskManager received an message to unregister the task + validateUnregisterTask(task.getExecutionId()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testExecutionFailsInNetworkRegistration() { + try { + // mock a working library cache + LibraryCacheManager libCache = mock(LibraryCacheManager.class); + when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); + + // mock a network manager that rejects registration + ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); + ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); + NetworkEnvironment network = mock(NetworkEnvironment.class); + when(network.getPartitionManager()).thenReturn(partitionManager); + when(network.getPartitionConsumableNotifier()).thenReturn(consumableNotifier); + when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + doThrow(new RuntimeException("buffers")).when(network).registerTask(any(Task.class)); + + Task task = createTask(TestInvokableCorrect.class, libCache, network); + + task.registerExecutionListener(listenerActor); + + task.run(); + + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("buffers")); - task.markFailed(new Exception("test")); - assertTrue(ExecutionState.CANCELED == task.getExecutionState()); + validateUnregisterTask(task.getExecutionId()); + validateListenerMessage(ExecutionState.FAILED, task, true); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } - verify(task).unregisterTask(); + @Test + public void testInvokableInstantiationFailed() { + try { + Task task = createTask(InvokableNonInstantiable.class); + task.registerExecutionListener(listenerActor); + + task.run(); + + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("instantiate")); + + validateUnregisterTask(task.getExecutionId()); + validateListenerMessage(ExecutionState.FAILED, task, true); } catch (Exception e) { e.printStackTrace(); @@ -93,48 +298,19 @@ public class TaskTest { } @Test - public void testTaskStartFinish() { + public void testExecutionFailsInRegisterInputOutput() { try { - final JobID jid = new JobID(); - final JobVertexID vid = new JobVertexID(); - final ExecutionAttemptID eid = new ExecutionAttemptID(); - - final Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender())); - doNothing().when(task).unregisterTask(); - - final AtomicReference error = new AtomicReference(); - - Thread operation = new Thread() { - @Override - public void run() { - try { - assertTrue(task.markAsFinished()); - } - catch (Throwable t) { - error.set(t); - } - } - }; - - final RuntimeEnvironment env = mock(RuntimeEnvironment.class); - when(env.getExecutingThread()).thenReturn(operation); - - assertEquals(ExecutionState.DEPLOYING, task.getExecutionState()); - - // start the execution - task.setEnvironment(env); - task.startExecution(); - - // wait for the execution to be finished - operation.join(); - - if (error.get() != null) { - ExceptionUtils.rethrow(error.get()); - } - - assertEquals(ExecutionState.FINISHED, task.getExecutionState()); + Task task = createTask(InvokableWithExceptionInRegisterInOut.class); + 
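testExecutionFailsInNetworkRegistration above combines the two Mockito stubbing styles: when(...).thenReturn(...) for methods that return a value, and doThrow(...).when(...) for void methods such as registerTask(...). A self-contained, Flink-independent reminder of the difference; Registry and failingRegistry are invented for the example:

import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class StubbingStylesExample {

	interface Registry {
		String lookup(String key);
		void register(Object item);
	}

	static Registry failingRegistry() {
		Registry registry = mock(Registry.class);

		// value-returning method: stub via when(...).thenReturn(...)
		when(registry.lookup(anyString())).thenReturn("default");

		// void method: stub via doThrow(...).when(...), as done for registerTask(...) above
		doThrow(new RuntimeException("rejected")).when(registry).register(any());

		return registry;
	}
}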
task.registerExecutionListener(listenerActor); - verify(task).unregisterTask(); + task.run(); + + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("registerInputOutput")); + + validateUnregisterTask(task.getExecutionId()); + validateListenerMessage(ExecutionState.FAILED, task, true); } catch (Exception e) { e.printStackTrace(); @@ -143,48 +319,22 @@ public class TaskTest { } @Test - public void testTaskFailesInRunning() { + public void testExecutionFailsInInvoke() { try { - final JobID jid = new JobID(); - final JobVertexID vid = new JobVertexID(); - final ExecutionAttemptID eid = new ExecutionAttemptID(); - - final Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender())); - doNothing().when(task).unregisterTask(); - - final AtomicReference error = new AtomicReference(); - - Thread operation = new Thread() { - @Override - public void run() { - try { - task.markFailed(new Exception("test exception message")); - } - catch (Throwable t) { - error.set(t); - } - } - }; - - final RuntimeEnvironment env = mock(RuntimeEnvironment.class); - when(env.getExecutingThread()).thenReturn(operation); + Task task = createTask(InvokableWithExceptionInInvoke.class); + task.registerExecutionListener(listenerActor); - assertEquals(ExecutionState.DEPLOYING, task.getExecutionState()); - - // start the execution - task.setEnvironment(env); - task.startExecution(); - - // wait for the execution to be finished - operation.join(); - - if (error.get() != null) { - ExceptionUtils.rethrow(error.get()); - } - - // make sure the final state is correct and the task manager knows the changes + task.run(); + assertEquals(ExecutionState.FAILED, task.getExecutionState()); - verify(task).unregisterTask(); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("test")); + + validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); + validateUnregisterTask(task.getExecutionId()); + + validateListenerMessage(ExecutionState.RUNNING, task, false); + validateListenerMessage(ExecutionState.FAILED, task, true); } catch (Exception e) { e.printStackTrace(); @@ -193,59 +343,180 @@ public class TaskTest { } @Test - public void testTaskCanceledInRunning() { + public void testCancelDuringRegisterInputOutput() { try { - final JobID jid = new JobID(); - final JobVertexID vid = new JobVertexID(); - final ExecutionAttemptID eid = new ExecutionAttemptID(); - - final Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender())); - doNothing().when(task).unregisterTask(); - - final AtomicReference error = new AtomicReference(); - - // latches to create a deterministic order of events - final OneShotLatch toRunning = new OneShotLatch(); - final OneShotLatch afterCanceling = new OneShotLatch(); - - Thread operation = new Thread() { - @Override - public void run() { - try { - toRunning.trigger(); - afterCanceling.await(); - assertFalse(task.markAsFinished()); - task.cancelingDone(); - } - catch (Throwable t) { - error.set(t); - } - } - }; + Task task = createTask(InvokableBlockingInRegisterInOut.class); + task.registerExecutionListener(listenerActor); + + // run the task asynchronous + task.startTaskThread(); - final RuntimeEnvironment env = mock(RuntimeEnvironment.class); - when(env.getExecutingThread()).thenReturn(operation); + // wait till the task is in regInOut + awaitLatch.await(); - assertEquals(ExecutionState.DEPLOYING, 
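The cancel- and fail-during-execution tests above all use the same handshake: start the task thread, block on awaitLatch until the invokable has reached the interesting point, act on the task, then release it via triggerLatch and join the executing thread. Reduced to plain JDK latches, and shown purely as an illustration rather than Flink code, the handshake is:

import java.util.concurrent.CountDownLatch;

// Plain-JDK illustration of the two-latch handshake used by the tests above.
public class TwoLatchHandshakeExample {

	public static void main(String[] args) throws Exception {
		final CountDownLatch workerReady = new CountDownLatch(1);   // plays the role of awaitLatch
		final CountDownLatch released = new CountDownLatch(1);      // plays the role of triggerLatch

		Thread worker = new Thread(new Runnable() {
			@Override
			public void run() {
				workerReady.countDown();          // "I am inside invoke() / registerInputOutput()"
				try {
					released.await();             // stay there until the test has acted
				}
				catch (InterruptedException ignored) {
				}
			}
		});
		worker.start();

		workerReady.await();                      // test: wait until the worker reaches the blocking point
		// ... in the real tests: task.cancelExecution() or task.failExternally(...), plus assertions ...
		released.countDown();                     // let the worker continue
		worker.join();
	}
}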
task.getExecutionState()); + task.cancelExecution(); + assertEquals(ExecutionState.CANCELING, task.getExecutionState()); + triggerLatch.trigger(); - // start the execution - task.setEnvironment(env); - task.startExecution(); + task.getExecutingThread().join(); + + assertEquals(ExecutionState.CANCELED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); - toRunning.await(); + validateUnregisterTask(task.getExecutionId()); + validateListenerMessage(ExecutionState.CANCELING, task, false); + validateListenerMessage(ExecutionState.CANCELED, task, false); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testFailDuringRegisterInputOutput() { + try { + Task task = createTask(InvokableBlockingInRegisterInOut.class); + task.registerExecutionListener(listenerActor); + + // run the task asynchronous + task.startTaskThread(); + + // wait till the task is in regInOut + awaitLatch.await(); + + task.failExternally(new Exception("test")); + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + triggerLatch.trigger(); + + task.getExecutingThread().join(); + + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("test")); + + validateUnregisterTask(task.getExecutionId()); + validateListenerMessage(ExecutionState.FAILED, task, true); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCancelDuringInvoke() { + try { + Task task = createTask(InvokableBlockingInInvoke.class); + task.registerExecutionListener(listenerActor); + + // run the task asynchronous + task.startTaskThread(); + + // wait till the task is in invoke + awaitLatch.await(); + task.cancelExecution(); - afterCanceling.trigger(); + assertTrue(task.getExecutionState() == ExecutionState.CANCELING || + task.getExecutionState() == ExecutionState.CANCELED); + + task.getExecutingThread().join(); + + assertEquals(ExecutionState.CANCELED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); + + validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); + validateUnregisterTask(task.getExecutionId()); - // wait for the execution to be finished - operation.join(); + validateListenerMessage(ExecutionState.RUNNING, task, false); + validateListenerMessage(ExecutionState.CANCELING, task, false); + validateListenerMessage(ExecutionState.CANCELED, task, false); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testFailExternallyDuringInvoke() { + try { + Task task = createTask(InvokableBlockingInInvoke.class); + task.registerExecutionListener(listenerActor); + + // run the task asynchronous + task.startTaskThread(); + + // wait till the task is in regInOut + awaitLatch.await(); + + task.failExternally(new Exception("test")); + assertTrue(task.getExecutionState() == ExecutionState.FAILED); + + task.getExecutingThread().join(); + + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("test")); + + validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); + validateUnregisterTask(task.getExecutionId()); + + validateListenerMessage(ExecutionState.RUNNING, task, false); + validateListenerMessage(ExecutionState.FAILED, task, true); + } + catch 
(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCanceledAfterExecutionFailedInRegInOut() { + try { + Task task = createTask(InvokableWithExceptionInRegisterInOut.class); + task.registerExecutionListener(listenerActor); + + task.run(); - if (error.get() != null) { - ExceptionUtils.rethrow(error.get()); - } + // this should not overwrite the failure state + task.cancelExecution(); + + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("registerInputOutput")); + + validateUnregisterTask(task.getExecutionId()); + validateListenerMessage(ExecutionState.FAILED, task, true); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCanceledAfterExecutionFailedInInvoke() { + try { + Task task = createTask(InvokableWithExceptionInInvoke.class); + task.registerExecutionListener(listenerActor); + + task.run(); + + // this should not overwrite the failure state + task.cancelExecution(); + + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("test")); + + validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); + validateUnregisterTask(task.getExecutionId()); - // make sure the final state is correct and the task manager knows the changes - assertEquals(ExecutionState.CANCELED, task.getExecutionState()); - verify(task).unregisterTask(); + validateListenerMessage(ExecutionState.RUNNING, task, false); + validateListenerMessage(ExecutionState.FAILED, task, true); } catch (Exception e) { e.printStackTrace(); @@ -254,79 +525,205 @@ public class TaskTest { } @Test - public void testTaskWithEnvironment() { + public void testExecutionFailesAfterCanceling() { try { - final JobID jid = new JobID(); - final JobVertexID vid = new JobVertexID(); - final ExecutionAttemptID eid = new ExecutionAttemptID(); - - TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(jid, vid, eid, "TestTask", 2, 7, - new Configuration(), new Configuration(), TestInvokableCorrect.class.getName(), - Collections.emptyList(), - Collections.emptyList(), - new ArrayList(), 0); - - Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender())); - doNothing().when(task).unregisterTask(); + Task task = createTask(InvokableWithExceptionOnTrigger.class); + task.registerExecutionListener(listenerActor); + + // run the task asynchronous + task.startTaskThread(); + + // wait till the task is in invoke + awaitLatch.await(); + + task.cancelExecution(); + assertEquals(ExecutionState.CANCELING, task.getExecutionState()); - RuntimeEnvironment env = new RuntimeEnvironment(mock(ActorRef.class), task, tdd, getClass().getClassLoader(), - mock(MemoryManager.class), mock(IOManager.class), mock(InputSplitProvider.class), - new BroadcastVariableManager(), MockNetworkEnvironment.getMock()); + // this causes an exception + triggerLatch.trigger(); + + task.getExecutingThread().join(); - task.setEnvironment(env); + // we should still be in state canceled + assertEquals(ExecutionState.CANCELED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); - assertEquals(ExecutionState.DEPLOYING, task.getExecutionState()); + validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); + validateUnregisterTask(task.getExecutionId()); + + 
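testCanceledAfterExecutionFailedInRegInOut and testCanceledAfterExecutionFailedInInvoke assert that a later cancelExecution() cannot overwrite the FAILED state. One common way to obtain such terminal-state precedence, shown here only as a generic illustration and not as the actual Task implementation, is to guard every transition with a compare-and-set on the expected current state:

import java.util.concurrent.atomic.AtomicReference;

// Generic illustration of terminal-state precedence; not the actual Task code.
class StateGuardExample {

	enum State { CREATED, RUNNING, CANCELING, CANCELED, FAILED }

	private final AtomicReference<State> state = new AtomicReference<State>(State.CREATED);

	boolean transition(State expected, State target) {
		// only succeeds if nobody has moved the state away from 'expected' in the meantime
		return state.compareAndSet(expected, target);
	}

	void cancel() {
		// a task that already reached FAILED stays FAILED: this CAS simply fails
		transition(State.RUNNING, State.CANCELING);
	}
}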
validateListenerMessage(ExecutionState.RUNNING, task, false); + validateListenerMessage(ExecutionState.CANCELING, task, false); + validateListenerMessage(ExecutionState.CANCELED, task, false); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testExecutionFailesAfterTaskMarkedFailed() { + try { + Task task = createTask(InvokableWithExceptionOnTrigger.class); + task.registerExecutionListener(listenerActor); + + // run the task asynchronous + task.startTaskThread(); + + // wait till the task is in invoke + awaitLatch.await(); + + task.failExternally(new Exception("external")); + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + + // this causes an exception + triggerLatch.trigger(); + + task.getExecutingThread().join(); - task.startExecution(); - task.getEnvironment().getExecutingThread().join(); + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("external")); - assertEquals(ExecutionState.FINISHED, task.getExecutionState()); + validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); + validateUnregisterTask(task.getExecutionId()); - verify(task).unregisterTask(); + validateListenerMessage(ExecutionState.RUNNING, task, false); + validateListenerMessage(ExecutionState.FAILED, task, true); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } + + private Task createTask(Class invokable) { + LibraryCacheManager libCache = mock(LibraryCacheManager.class); + when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); + return createTask(invokable, libCache); + } + + private Task createTask(Class invokable, + LibraryCacheManager libCache) { + + ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); + ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); + NetworkEnvironment network = mock(NetworkEnvironment.class); + when(network.getPartitionManager()).thenReturn(partitionManager); + when(network.getPartitionConsumableNotifier()).thenReturn(consumableNotifier); + when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + + return createTask(invokable, libCache, network); + } - @Test - public void testTaskWithEnvironmentAndException() { + private Task createTask(Class invokable, + LibraryCacheManager libCache, + NetworkEnvironment networkEnvironment) { + + TaskDeploymentDescriptor tdd = createTaskDeploymentDescriptor(invokable); + + return new Task(tdd, + mock(MemoryManager.class), + mock(IOManager.class), + networkEnvironment, + mock(BroadcastVariableManager.class), + taskManagerMock, jobManagerMock, + new FiniteDuration(60, TimeUnit.SECONDS), + libCache, + mock(FileCache.class)); + } + + private TaskDeploymentDescriptor createTaskDeploymentDescriptor(Class invokable) { + return new TaskDeploymentDescriptor( + new JobID(), new JobVertexID(), new ExecutionAttemptID(), + "Test Task", 0, 1, + new Configuration(), new Configuration(), + invokable.getName(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + 0); + } + + // ------------------------------------------------------------------------ + // Validation Methods + // ------------------------------------------------------------------------ + + private void validateUnregisterTask(ExecutionAttemptID id) { try { - final JobID jid = new JobID(); - final JobVertexID vid = new JobVertexID(); - final 
ExecutionAttemptID eid = new ExecutionAttemptID(); + // we may have to wait for a bit to give the actors time to receive the message + // and put it into the queue + Object rawMessage = taskManagerMessages.poll(10, TimeUnit.SECONDS); - TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(jid, vid, eid, "TestTask", 2, 7, - new Configuration(), new Configuration(), TestInvokableWithException.class.getName(), - Collections.emptyList(), - Collections.emptyList(), - new ArrayList(), 0); + assertNotNull("There is no additional TaskManager message", rawMessage); + assertTrue("TaskManager message is not 'UnregisterTask'", rawMessage instanceof TaskMessages.TaskInFinalState); - Task task = spy(new Task(jid, vid, 2, 7, eid, "TestTask", ActorRef.noSender())); - doNothing().when(task).unregisterTask(); - - RuntimeEnvironment env = new RuntimeEnvironment(mock(ActorRef.class), task, tdd, getClass().getClassLoader(), - mock(MemoryManager.class), mock(IOManager.class), mock(InputSplitProvider.class), - new BroadcastVariableManager(), MockNetworkEnvironment.getMock()); + TaskMessages.TaskInFinalState message = (TaskMessages.TaskInFinalState) rawMessage; + assertEquals(id, message.executionID()); + } + catch (InterruptedException e) { + fail("interrupted"); + } + } - task.setEnvironment(env); + private void validateTaskManagerStateChange(ExecutionState state, Task task, boolean hasError) { + try { + // we may have to wait for a bit to give the actors time to receive the message + // and put it into the queue + Object rawMessage = taskManagerMessages.poll(10, TimeUnit.SECONDS); + + assertNotNull("There is no additional TaskManager message", rawMessage); + assertTrue("TaskManager message is not 'UpdateTaskExecutionState'", + rawMessage instanceof TaskMessages.UpdateTaskExecutionState); - assertEquals(ExecutionState.DEPLOYING, task.getExecutionState()); + TaskMessages.UpdateTaskExecutionState message = + (TaskMessages.UpdateTaskExecutionState) rawMessage; - task.startExecution(); - task.getEnvironment().getExecutingThread().join(); + TaskExecutionState taskState = message.taskExecutionState(); + + assertEquals(task.getJobID(), taskState.getJobID()); + assertEquals(task.getExecutionId(), taskState.getID()); + assertEquals(state, taskState.getExecutionState()); + + if (hasError) { + assertNotNull(taskState.getError(getClass().getClassLoader())); + } else { + assertNull(taskState.getError(getClass().getClassLoader())); + } + } + catch (InterruptedException e) { + fail("interrupted"); + } + } + + private void validateListenerMessage(ExecutionState state, Task task, boolean hasError) { + try { + // we may have to wait for a bit to give the actors time to receive the message + // and put it into the queue + TaskMessages.UpdateTaskExecutionState message = + (TaskMessages.UpdateTaskExecutionState) listenerMessages.poll(10, TimeUnit.SECONDS); + assertNotNull("There is no additional listener message", message); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); + TaskExecutionState taskState = message.taskExecutionState(); - verify(task).unregisterTask(); + assertEquals(task.getJobID(), taskState.getJobID()); + assertEquals(task.getExecutionId(), taskState.getID()); + assertEquals(state, taskState.getExecutionState()); + + if (hasError) { + assertNotNull(taskState.getError(getClass().getClassLoader())); + } else { + assertNull(taskState.getError(getClass().getClassLoader())); + } } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); + catch (InterruptedException e) { + 
fail("interrupted"); } } + // -------------------------------------------------------------------------------------------- + // Mock invokable code // -------------------------------------------------------------------------------------------- public static final class TestInvokableCorrect extends AbstractInvokable { @@ -336,16 +733,93 @@ public class TaskTest { @Override public void invoke() {} + + @Override + public void cancel() throws Exception { + fail("This should not be called"); + } + } + + public static final class InvokableWithExceptionInRegisterInOut extends AbstractInvokable { + + @Override + public void registerInputOutput() { + throw new RuntimeException("test"); + } + + @Override + public void invoke() {} } - public static final class TestInvokableWithException extends AbstractInvokable { + public static final class InvokableWithExceptionInInvoke extends AbstractInvokable { + + @Override + public void registerInputOutput() {} + + @Override + public void invoke() throws Exception { + throw new Exception("test"); + } + } + + public static final class InvokableWithExceptionOnTrigger extends AbstractInvokable { + + @Override + public void registerInputOutput() {} + + @Override + public void invoke() { + awaitLatch.trigger(); + + // make sure that the interrupt call does not + // grab us out of the lock early + while (true) { + try { + triggerLatch.await(); + break; + } + catch (InterruptedException e) { + // fall through the loop + } + } + + throw new RuntimeException("test"); + } + } + + public static abstract class InvokableNonInstantiable extends AbstractInvokable {} + + public static final class InvokableBlockingInRegisterInOut extends AbstractInvokable { + + @Override + public void registerInputOutput() { + awaitLatch.trigger(); + + try { + triggerLatch.await(); + } + catch (InterruptedException e) { + throw new RuntimeException(); + } + } + + @Override + public void invoke() {} + } + + public static final class InvokableBlockingInInvoke extends AbstractInvokable { @Override public void registerInputOutput() {} @Override public void invoke() throws Exception { - throw new Exception("test exception"); + awaitLatch.trigger(); + + // block forever + synchronized (this) { + wait(); + } } } } diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManager.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManager.scala index bb0c1f90b25b22bc12100a1c6c1c4c4223eb2e76..53182540546fa9e70d8fcd91de9f30598e4fdb99 100644 --- a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManager.scala +++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManager.scala @@ -27,7 +27,7 @@ import org.apache.flink.runtime.io.disk.iomanager.IOManager import org.apache.flink.runtime.io.network.NetworkEnvironment import org.apache.flink.runtime.memorymanager.DefaultMemoryManager import org.apache.flink.runtime.messages.Messages.Disconnect -import org.apache.flink.runtime.messages.TaskMessages.{UpdateTaskExecutionState, UnregisterTask} +import org.apache.flink.runtime.messages.TaskMessages.{TaskInFinalState, UpdateTaskExecutionState} import org.apache.flink.runtime.taskmanager.{TaskManagerConfiguration, TaskManager} import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.NotifyWhenJobRemoved import org.apache.flink.runtime.testingUtils.TestingMessages.DisableDisconnect @@ -95,8 +95,8 @@ class TestingTaskManager(config: TaskManagerConfiguration, } } - case 
UnregisterTask(executionID) => - super.receiveWithLogMessages(UnregisterTask(executionID)) + case TaskInFinalState(executionID) => + super.receiveWithLogMessages(TaskInFinalState(executionID)) waitForRemoval.remove(executionID) match { case Some(actors) => for(actor <- actors) actor ! true case None =>
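InvokableWithExceptionOnTrigger above deliberately loops around triggerLatch.await() so that the interrupt issued during cancellation cannot release it before the test does. Extracted into a hypothetical helper (UninterruptibleAwait is a made-up name), and with the usual restoration of the interrupt flag added, the idiom reads:

import java.util.concurrent.CountDownLatch;

// Hypothetical helper capturing the "survive the interrupt" idiom used by the invokable above.
final class UninterruptibleAwait {

	static void awaitUninterruptibly(CountDownLatch latch) {
		boolean interrupted = false;
		while (true) {
			try {
				latch.await();
				break;
			}
			catch (InterruptedException e) {
				interrupted = true;       // remember the interrupt, but keep waiting
			}
		}
		if (interrupted) {
			Thread.currentThread().interrupt();   // restore the interrupt status for the caller
		}
	}
}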