diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc index d334520a93f8e0cc0cfa42e31fa0aa72e714717c..6ce1eac2e30d24e3752ba08d77637005f5360c01 100644 --- a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc @@ -30,9 +30,6 @@ namespace paddle { namespace framework { namespace details { -static std::atomic exec_op_count_; -static std::atomic error_state; - BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &local_exec_scopes, @@ -126,18 +123,21 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( ready_ops->Push(cur_op); } - exec_op_count_ = 0; + { + std::lock_guard lock(mutex_); + exec_op_count_ = 0; + } platform::XPUPlace cur_place; std::size_t cur_count = 0; - while (cur_count < op_deps_.size()) { + while (cur_count < op_deps->size()) { cur_count++; auto cur_op = ready_ops->Pop(); + // when execption, get cur_op == nullptr if (cur_op == nullptr) { - // sleep a while to make sure worker thread quit - sleep(10); - exec_op_count_ = op_deps_.size(); + std::lock_guard lock(mutex_); + exec_op_count_ = op_deps->size(); break; } auto dev_ctxes_ = cur_op->DeviceContext(); @@ -151,14 +151,17 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index); } } - while (exec_op_count_ < op_deps_.size()) { + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return exec_op_count_ >= op_deps->size(); }); } - // Wait FetchOps. - ClearFetchOp(graph_, &fetch_ops); if (exception_.IsCaught()) { ExecutionFinal(&fetch_ops); } + + // Wait FetchOps. + ClearFetchOp(graph_, &fetch_ops); return fetches; } @@ -222,7 +225,8 @@ void BindThreadedSSAGraphExecutor::InsertFetchOps( } } } - +// RunMultiDeviceOpAsync function is used for Communicated OPs +// like all_reduce\broadcast among multicards. void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync( OpHandleBase *op, std::unordered_map *op_deps, @@ -256,10 +260,14 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync( ready_ops->Push(nullptr); exception_.Catch(std::current_exception()); } - exec_op_count_++; + { + std::lock_guard lock(mutex_); + exec_op_count_++; + cv_.notify_all(); + } }); } - +// RunOpAsyncMainStream function is used for computed OPs void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream( OpHandleBase *op, std::unordered_map *op_deps, @@ -285,7 +293,11 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream( ready_ops->Push(nullptr); exception_.Catch(std::current_exception()); } - exec_op_count_++; + { + std::lock_guard lock(mutex_); + exec_op_count_++; + cv_.notify_all(); + } }); } diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h index 87c1908944e70c0dc4bfd09509b25dc3788c4cf0..5e973f13cc618a28c370b82576cc72c7fb499495 100644 --- a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h @@ -14,7 +14,9 @@ #pragma once #include +#include // NOLINT #include +#include // NOLINT #include #include #include @@ -76,6 +78,11 @@ class BindThreadedSSAGraphExecutor : public SSAGraphExecutor { ::ThreadPool prepare_pool_; ::ThreadPool multi_device_op_pool_; + std::mutex mutex_; + std::condition_variable cv_; + uint32_t exec_op_count_; + std::atomic error_state; + void RunOpAsyncMainStream( OpHandleBase *op, std::unordered_map *op_deps,