diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h
index 5d437594ab754ce6ee6becea3574c58cc59cbc63..0a60488d9f629c253bf66f1704fc4cde9e9b65a0 100644
--- a/paddle/fluid/framework/threadpool.h
+++ b/paddle/fluid/framework/threadpool.h
@@ -28,6 +28,22 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+struct ExceptionHandler {
+  mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
+  explicit ExceptionHandler(
+      std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
+      : future_(std::move(f)) {}
+  void operator()() const {
+    auto ex = this->future_.get();
+    if (ex != nullptr) {
+      LOG(FATAL) << "The exception is thrown inside the thread pool. You "
+                    "should use RunAndGetException to handle the exception.\n"
+                    "The default exception handler is LOG(FATAL)."
+                 << ex->what();
+    }
+  }
+};
+
 // ThreadPool maintains a queue of tasks, and runs them using a fixed
 // number of threads.
 class ThreadPool {
@@ -87,22 +103,6 @@ class ThreadPool {
   void Wait();
 
  private:
-  struct ExceptionHandler {
-    mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
-    explicit ExceptionHandler(
-        std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
-        : future_(std::move(f)) {}
-    void operator()() const {
-      auto ex = this->future_.get();
-      if (ex != nullptr) {
-        LOG(FATAL) << "The exception is thrown inside the thread pool. You "
-                      "should use RunAndGetException to handle the exception.\n"
-                      "The default exception handler is LOG(FATAL)."
-                   << ex->what();
-      }
-    }
-  };
-
   DISABLE_COPY_AND_ASSIGN(ThreadPool);
 
   // If the task queue is empty and avaialbe is equal to the number of
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index 3f96ce37185bcc7aa0baa9cb9b6fada3c838c21f..d79ba6d291950e1f089eb11713bd1c3e4d154b27 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -33,8 +33,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
   const framework::Scope* p_scope = &scope;
   const auto ch = GetChannel(ep_val);
 
-  framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
-                      this] {
+  framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
     auto* var = p_scope->FindVar(var_name_val);
 
     ::grpc::ByteBuffer req;
@@ -89,8 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   const framework::Scope* p_scope = &scope;
   const auto ch = GetChannel(ep_val);
 
-  framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
-                      this] {
+  framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
@@ -133,8 +131,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
   const framework::Scope* p_scope = &scope;
   const auto ch = GetChannel(ep_val);
 
-  framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
-                      time_out, ch, this] {
+  framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
+                    time_out, ch, this] {
     auto* var = p_scope->FindVar(in_var_name_val);
 
     ::grpc::ByteBuffer req;
@@ -197,7 +195,7 @@ bool RPCClient::Wait() {
   std::vector<std::future<void>> waits(req_count_);
 
   for (int i = 0; i < req_count_; i++) {
-    waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); });
+    waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); });
   }
 
   for (int i = 0; i < req_count_; i++) {
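
Reviewer note, not part of the patch: the hunks above move ExceptionHandler from a private member of ThreadPool to namespace scope and switch the gRPC client from framework::AsyncIO to framework::Async. The sketch below is illustrative only; it assumes the ThreadPool::GetInstance()/RunAndGetException() interface declared in threadpool.h at the time of this change, while framework::Async returning std::future<void> is taken from the Wait() hunk itself.

// Sketch under the assumptions above; identifiers other than
// framework::Async, ExceptionHandler, and RunAndGetException are taken
// from threadpool.h as of this change and may differ in other versions.
#include <future>

#include "paddle/fluid/framework/threadpool.h"

void Demo() {
  // Default path: Async() wraps the task's future in ExceptionHandler, so an
  // exception escaping the lambda ends in LOG(FATAL) (see the struct above).
  std::future<void> done =
      paddle::framework::Async([] { /* e.g. issue an RPC request */ });
  done.wait();

  // Caller-handled path: RunAndGetException() hands back the captured
  // platform::EnforceNotMet (or nullptr on success) instead of aborting.
  auto ex_future =
      paddle::framework::ThreadPool::GetInstance()->RunAndGetException(
          [] { /* work that may throw */ });
  if (auto ex = ex_future.get()) {
    // inspect ex->what() and recover, rethrow, or log as appropriate
  }
}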