diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
index d334520a93f8e0cc0cfa42e31fa0aa72e714717c..7cfe28fd7616d492effe17af4023b5d48330e430 100644
--- a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
@@ -30,9 +30,6 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-static std::atomic exec_op_count_;
-static std::atomic error_state;
-
 BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector &local_scopes,
     const std::vector &local_exec_scopes,
@@ -125,7 +122,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
   for (auto cur_op : ready_fetch_ops) {
     ready_ops->Push(cur_op);
   }
-
+  // Atomic variable, no lock needed
   exec_op_count_ = 0;
 
   platform::XPUPlace cur_place;
@@ -134,9 +131,8 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
   while (cur_count < op_deps_.size()) {
     cur_count++;
     auto cur_op = ready_ops->Pop();
+    // on exception, a nullptr is pushed, so cur_op == nullptr here
     if (cur_op == nullptr) {
-      // sleep a while to make sure worker thread quit
-      sleep(10);
       exec_op_count_ = op_deps_.size();
       break;
     }
@@ -151,14 +147,16 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
       RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index);
     }
   }
-  while (exec_op_count_ < op_deps_.size()) {
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    cv_.wait(lock, [&] { return exec_op_count_ >= op_deps_.size(); });
   }
-
-  // Wait FetchOps.
-  ClearFetchOp(graph_, &fetch_ops);
 
   if (exception_.IsCaught()) {
     ExecutionFinal(&fetch_ops);
   }
+
+  // Wait FetchOps.
+  ClearFetchOp(graph_, &fetch_ops);
   return fetches;
 }
@@ -222,7 +220,8 @@ void BindThreadedSSAGraphExecutor::InsertFetchOps(
       }
     }
   }
-
+// RunMultiDeviceOpAsync is used for communication OPs
+// such as all_reduce/broadcast among multiple cards.
 void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
     OpHandleBase *op,
     std::unordered_map *op_deps,
@@ -256,10 +255,12 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
       ready_ops->Push(nullptr);
      exception_.Catch(std::current_exception());
     }
+    // Atomic variable, no lock needed
     exec_op_count_++;
+    cv_.notify_all();
   });
 }
-
+// RunOpAsyncMainStream is used for computation OPs
 void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
     OpHandleBase *op,
     std::unordered_map *op_deps,
@@ -285,7 +286,9 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
       ready_ops->Push(nullptr);
       exception_.Catch(std::current_exception());
     }
+    // Atomic variable, no lock needed
     exec_op_count_++;
+    cv_.notify_all();
   });
 }
 
diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h
index 87c1908944e70c0dc4bfd09509b25dc3788c4cf0..b92ba7a0df0a89b4aaecdaad21e21a696fc22721 100644
--- a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h
@@ -14,7 +14,9 @@
 
 #pragma once
 #include
+#include <condition_variable>  // NOLINT
 #include
+#include <mutex>  // NOLINT
 #include
 #include
 #include
@@ -76,6 +78,11 @@ class BindThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ::ThreadPool prepare_pool_;
   ::ThreadPool multi_device_op_pool_;
 
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  std::atomic exec_op_count_;
+  std::atomic error_state;
+
   void RunOpAsyncMainStream(
       OpHandleBase *op,
       std::unordered_map *op_deps,
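Note for reviewers: the core change replaces the previous busy-wait loop (plus sleep(10)) in RunMainStream with a condition-variable wait on an operator-completion counter, with each worker thread notifying cv_ after bumping exec_op_count_. Below is a minimal, self-contained sketch of that completion-counting pattern, not Paddle code; names such as CompletionCounter, TaskDone, and kNumTasks are illustrative assumptions, and unlike the patch the counter is bumped under the mutex, which is the textbook-safe variant of the same idea.

// Sketch only (not Paddle code): workers report completion and notify a
// condition variable; the main thread blocks until every task has finished
// instead of spinning in a busy-wait loop.
#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class CompletionCounter {
 public:
  explicit CompletionCounter(std::size_t total) : total_(total) {}

  // Called by each worker when its task finishes.
  void TaskDone() {
    {
      // Update the counter under the mutex so the notification cannot slip
      // in between the waiter's predicate check and its sleep.
      std::lock_guard<std::mutex> guard(mutex_);
      ++done_;
    }
    cv_.notify_all();
  }

  // Called by the main thread; returns once all tasks report completion.
  void WaitAll() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return done_ >= total_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::size_t done_{0};
  const std::size_t total_;
};

int main() {
  constexpr std::size_t kNumTasks = 8;
  CompletionCounter counter(kNumTasks);

  std::vector<std::thread> workers;
  for (std::size_t i = 0; i < kNumTasks; ++i) {
    workers.emplace_back([&counter] { counter.TaskDone(); });
  }

  counter.WaitAll();  // Blocks until every worker has called TaskDone().
  for (auto &t : workers) t.join();
  std::cout << "all tasks finished" << std::endl;
  return 0;
}

Waking the waiter with notify_all after the final increment is what lets RunMainStream return as soon as the last op finishes, rather than polling the counter or sleeping for a fixed interval before shutdown.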