From 29467060a3c471b1f378e82f4823c244b306e7f8 Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 23 Feb 2021 17:18:42 +0800 Subject: [PATCH] [cherry-pick 2.0.1] [kunlun] fix xpu bind threaded executor (#31116) * [Kunlun] Add condition_variable and notify() in BindThreadedSSAGraphExecutor (#30586) * [Kunlun] fix dead lock for exec_op_count_ (#30718) * Fix the problem that the number of ops executed by xpu is wrong (#30961) Co-authored-by: liuyuhui --- .../bind_threaded_ssa_graph_executor.cc | 42 ++++++++++++------- .../bind_threaded_ssa_graph_executor.h | 7 ++++ 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc index d334520a93f..6ce1eac2e30 100644 --- a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc @@ -30,9 +30,6 @@ namespace paddle { namespace framework { namespace details { -static std::atomic exec_op_count_; -static std::atomic error_state; - BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &local_exec_scopes, @@ -126,18 +123,21 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( ready_ops->Push(cur_op); } - exec_op_count_ = 0; + { + std::lock_guard lock(mutex_); + exec_op_count_ = 0; + } platform::XPUPlace cur_place; std::size_t cur_count = 0; - while (cur_count < op_deps_.size()) { + while (cur_count < op_deps->size()) { cur_count++; auto cur_op = ready_ops->Pop(); + // when an exception occurs, cur_op == nullptr if (cur_op == nullptr) { - // sleep a while to make sure worker thread quit - sleep(10); - exec_op_count_ = op_deps_.size(); + std::lock_guard lock(mutex_); + exec_op_count_ = op_deps->size(); break; } auto dev_ctxes_ = cur_op->DeviceContext(); @@ -151,14 +151,17 @@ FetchResultType 
BindThreadedSSAGraphExecutor::RunMainStream( RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index); } } - while (exec_op_count_ < op_deps_.size()) { + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return exec_op_count_ >= op_deps->size(); }); } - // Wait FetchOps. - ClearFetchOp(graph_, &fetch_ops); if (exception_.IsCaught()) { ExecutionFinal(&fetch_ops); } + + // Wait FetchOps. + ClearFetchOp(graph_, &fetch_ops); return fetches; } @@ -222,7 +225,8 @@ void BindThreadedSSAGraphExecutor::InsertFetchOps( } } } - +// RunMultiDeviceOpAsync function is used for communication OPs +// such as all_reduce/broadcast across multiple cards. void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync( OpHandleBase *op, std::unordered_map *op_deps, @@ -256,10 +260,14 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync( ready_ops->Push(nullptr); exception_.Catch(std::current_exception()); } - exec_op_count_++; + { + std::lock_guard lock(mutex_); + exec_op_count_++; + cv_.notify_all(); + } }); } - +// RunOpAsyncMainStream function is used for computation OPs void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream( OpHandleBase *op, std::unordered_map *op_deps, @@ -285,7 +293,11 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream( ready_ops->Push(nullptr); exception_.Catch(std::current_exception()); } - exec_op_count_++; + { + std::lock_guard lock(mutex_); + exec_op_count_++; + cv_.notify_all(); + } }); } diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h index 87c1908944e..5e973f13cc6 100644 --- a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h @@ -14,7 +14,9 @@ #pragma once #include +#include // NOLINT #include +#include // NOLINT #include #include #include @@ -76,6 +78,11 @@ class BindThreadedSSAGraphExecutor : public SSAGraphExecutor { ::ThreadPool prepare_pool_; 
::ThreadPool multi_device_op_pool_; + std::mutex mutex_; + std::condition_variable cv_; + uint32_t exec_op_count_; + std::atomic error_state; + void RunOpAsyncMainStream( OpHandleBase *op, std::unordered_map *op_deps, -- GitLab