diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc index 7cfe28fd7616d492effe17af4023b5d48330e430..6d3c52dabbd0d8b0a6aab53893c3b5256f71a28e 100644 --- a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc @@ -122,8 +122,11 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( for (auto cur_op : ready_fetch_ops) { ready_ops->Push(cur_op); } - // Atomic variable, no need to lock - exec_op_count_ = 0; + + { + std::lock_guard lock(mutex_); + exec_op_count_ = 0; + } platform::XPUPlace cur_place; std::size_t cur_count = 0; @@ -133,6 +136,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( auto cur_op = ready_ops->Pop(); // when execption, get cur_op == nullptr if (cur_op == nullptr) { + std::lock_guard lock(mutex_); exec_op_count_ = op_deps_.size(); break; } @@ -151,6 +155,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( std::unique_lock lock(mutex_); cv_.wait(lock, [&] { return exec_op_count_ >= op_deps_.size(); }); } + if (exception_.IsCaught()) { ExecutionFinal(&fetch_ops); } @@ -255,9 +260,11 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync( ready_ops->Push(nullptr); exception_.Catch(std::current_exception()); } - // Atomic variable, no need to lock - exec_op_count_++; - cv_.notify_all(); + { + std::lock_guard lock(mutex_); + exec_op_count_++; + cv_.notify_all(); + } }); } // RunOpAsyncMainStream function is used for computed OPs @@ -286,9 +293,11 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream( ready_ops->Push(nullptr); exception_.Catch(std::current_exception()); } - // Atomic variable, no need to lock - exec_op_count_++; - cv_.notify_all(); + { + std::lock_guard lock(mutex_); + exec_op_count_++; + cv_.notify_all(); + } }); } diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h index b92ba7a0df0a89b4aaecdaad21e21a696fc22721..5e973f13cc618a28c370b82576cc72c7fb499495 100644 --- a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h @@ -80,7 +80,7 @@ class BindThreadedSSAGraphExecutor : public SSAGraphExecutor { std::mutex mutex_; std::condition_variable cv_; - std::atomic exec_op_count_; + uint32_t exec_op_count_; std::atomic error_state; void RunOpAsyncMainStream(