diff --git a/paddle/fluid/framework/details/exception_holder.h b/paddle/fluid/framework/details/exception_holder.h
new file mode 100644
index 0000000000000000000000000000000000000000..6e302a29233b96451df14b4685911be1cd87c1ab
--- /dev/null
+++ b/paddle/fluid/framework/details/exception_holder.h
@@ -0,0 +1,93 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <exception>
+#include <memory>
+#include <mutex>  // NOLINT
+#include <utility>
+
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+// Mutex-guarded slot that keeps at most one exception, so it can be recorded
+// where it is caught and re-thrown later from somewhere else.
+class ExceptionHolder {
+ public:
+  // Stores a copy of `exp`, replacing any exception that is already held.
+  void Catch(const platform::EnforceNotMet& exp) {
+    std::lock_guard<std::mutex> lock(mu_);
+    exception_.reset(new platform::EnforceNotMet(exp));
+    type_ = kEnforceNotMet;
+  }
+
+  // Stores a copy of `exp` only when no exception is held yet.
+  void Catch(const platform::EOFException& exp) {
+    std::lock_guard<std::mutex> lock(mu_);
+    // EOFException will not cover up existing EnforceNotMet.
+    if (exception_.get() == nullptr) {
+      exception_.reset(new platform::EOFException(exp));
+      type_ = kEOF;
+    }
+  }
+
+  // Returns true while an exception is held.
+  bool ExceptionCatched() const {
+    std::lock_guard<std::mutex> lock(mu_);
+    return exception_.get() != nullptr;
+  }
+
+  // Re-throws the held exception, if any, and leaves the holder empty.
+  void Throw() {
+    std::lock_guard<std::mutex> lock(mu_);
+    // Detach the stored exception before throwing: nothing placed after a
+    // `throw` statement would ever run, so the reset has to happen first.
+    std::unique_ptr<std::exception> held = std::move(exception_);
+    const ExceptionType held_type = type_;
+    type_ = kNone;
+    switch (held_type) {
+      case kNone:
+        break;
+      case kEnforceNotMet:
+        // `throw` copies its operand, so `held` may be destroyed on unwind.
+        throw *static_cast<platform::EnforceNotMet*>(held.get());
+      case kEOF:
+        throw *static_cast<platform::EOFException*>(held.get());
+      default:
+        LOG(FATAL) << "Unknown exception.";
+    }
+  }
+
+  // Drops the held exception, if any.
+  void Clear() {
+    std::lock_guard<std::mutex> lock(mu_);
+    exception_.reset();
+    type_ = kNone;
+  }
+
+ private:
+  enum ExceptionType { kNone, kEnforceNotMet, kEOF };
+  ExceptionType type_{kNone};
+
+  std::unique_ptr<std::exception> exception_;
+  mutable std::mutex mu_;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index c19f74476f9a1498a7d61f5faf204e9966aea155..00f1f262a6505881c72adc451b95077aa3719872 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -83,7 +83,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
 
   // Clean run context
   run_op_futures_.clear();
-  exception_.reset();
+  exception_holder_.Clear();
 
   // Step 3. Execution
   while (!pending_vars.empty()) {
@@ -103,23 +103,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
 
     if (timeout) {
-      std::unique_lock<std::mutex> l(exception_mu_);
-      if (exception_) {
-        l.unlock();
+      if (exception_holder_.ExceptionCatched()) {
         for (auto &run_op_future : run_op_futures_) {
           run_op_future.wait();
         }
-        l.lock();
-        std::exception *exp = exception_.get();
-        if (dynamic_cast<platform::EOFException *>(exp)) {
-          auto e = *static_cast<platform::EOFException *>(exp);
-          throw e;
-        } else if (dynamic_cast<platform::EnforceNotMet *>(exp)) {
-          auto e = *static_cast<platform::EnforceNotMet *>(exp);
-          throw e;
-        } else {
-          LOG(FATAL) << "Unknown exception.";
-        }
+        exception_holder_.Throw();
       } else {
         continue;
       }
@@ -229,14 +217,9 @@ void ThreadedSSAGraphExecutor::RunOp(
       ready_var_q->Extend(op->Outputs());
       VLOG(10) << op << " " << op->Name() << "Signal posted";
-    } catch (platform::EOFException ex) {
-      std::lock_guard<std::mutex> l(exception_mu_);
-      // EOFException will not cover up existing EnforceNotMet.
-      if (exception_.get() == nullptr) {
-        exception_.reset(new platform::EOFException(ex));
-      }
-    } catch (platform::EnforceNotMet ex) {
-      std::lock_guard<std::mutex> l(exception_mu_);
-      exception_.reset(new platform::EnforceNotMet(ex));
+    } catch (const platform::EOFException &ex) {
+      exception_holder_.Catch(ex);
+    } catch (const platform::EnforceNotMet &ex) {
+      exception_holder_.Catch(ex);
     } catch (...) {
       LOG(FATAL) << "Unknown exception catched";
     }
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index 3d67daa45e20fdea52689684397ad01f2f4cd783..4f3e5a6288be775990b64e86ce29271961effbe1 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -24,6 +24,7 @@
 #include <vector>
 #include "ThreadPool.h"  // ThreadPool in thrird party
 #include "paddle/fluid/framework/blocking_queue.h"
+#include "paddle/fluid/framework/details/exception_holder.h"
 #include "paddle/fluid/framework/details/execution_strategy.h"
 #include "paddle/fluid/framework/details/fetch_op_handle.h"
 #include "paddle/fluid/framework/details/ssa_graph_executor.h"
@@ -58,8 +59,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<Scope *> local_scopes_;
   std::vector<platform::Place> places_;
   platform::DeviceContextPool fetch_ctxs_;
-  std::mutex exception_mu_;
-  std::unique_ptr<std::exception> exception_;
+  ExceptionHolder exception_holder_;
   std::atomic<int> running_ops_;
 
   void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,