From 404a4a6a71b63b4949097f4fd039558a42ea640e Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 28 Dec 2021 13:16:16 +0800 Subject: [PATCH] [new-exec] add completion_nofifier (#38447) * add completion_nofifier * fix bug * unregist event waiter --- .../framework/new_executor/interpretercore.cc | 22 ++++++++++++------- .../framework/new_executor/interpretercore.h | 3 ++- .../new_executor/interpretercore_util.h | 4 ++-- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 5a4caf6af4..94a27294e8 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -30,6 +30,7 @@ DECLARE_bool(check_nan_inf); DECLARE_bool(benchmark); constexpr const char* kExceptionCaught = "ExceptionCaught"; +constexpr const char* kTaskCompletion = "TaskCompletion"; namespace paddle { namespace framework { @@ -49,6 +50,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place, gc_.reset(new InterpreterCoreGarbageCollector()); exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught); + completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion); create_local_scope_ = FLAGS_new_executor_use_local_scope; if (FLAGS_new_executor_use_local_scope) { @@ -69,6 +71,9 @@ InterpreterCore::~InterpreterCore() { // cancle gc's thread gc_.reset(nullptr); + exception_notifier_->UnregisterEvent(); + completion_notifier_->UnregisterEvent(); + async_work_queue_.reset(nullptr); } @@ -417,7 +422,7 @@ void InterpreterCore::ExecuteInstructionList( const std::vector& vec_instr) { async_work_queue_->PrepareAtomicDeps(dependecy_count_); async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo()); - op_run_number_ = 0; + unfinished_op_numer_ = vec_instr.size(); exception_holder_.Clear(); @@ -436,12 +441,6 @@ void InterpreterCore::ExecuteInstructionList( async_work_queue_->Cancel(); exception_holder_.ReThrow(); } - - PADDLE_ENFORCE_EQ( - op_run_number_.load(), vec_instr.size(), - platform::errors::Fatal( - "Required op_run_number == %d, but received op_run_number = %d.", - vec_instr.size(), op_run_number_.load())); } void InterpreterCore::RunNextInstructions( @@ -539,8 +538,15 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { return; } + VLOG(4) << "unfinished_op_numer_: " << unfinished_op_numer_; + if (UNLIKELY(unfinished_op_numer_.fetch_sub(1, std::memory_order_relaxed) == + 1)) { + if (completion_notifier_ != nullptr) { + completion_notifier_->NotifyEvent(); + } + } + interpreter::RecordEvent(instr_node, place_); - op_run_number_.fetch_add(1, std::memory_order_relaxed); RunNextInstructions(instr_node, &ready_ops); } diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 93ac7c0294..e6e6b7cdc3 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -101,7 +101,7 @@ class InterpreterCore { std::vector vec_instruction_; // deconstruct before OpFuncNode std::vector dependecy_count_; - std::atomic op_run_number_{0}; + std::atomic unfinished_op_numer_{0}; std::vector> input_var2op_info_; StreamAnalyzer stream_analyzer_; @@ -109,6 +109,7 @@ class InterpreterCore { std::unique_ptr async_work_queue_; details::ExceptionHolder exception_holder_; std::shared_ptr exception_notifier_{nullptr}; + std::shared_ptr completion_notifier_{nullptr}; std::unique_ptr gc_; std::vector gc_event_; diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index 14c27c94f8..5f403613c6 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -60,13 +60,13 @@ class AsyncWorkQueue { // for execute host Kernel group_options.emplace_back(/*num_threads*/ host_num_threads, /*allow_spinning*/ true, - /*track_task*/ true, + /*track_task*/ false, /*detached*/ true, /*events_waiter*/ waiter); // for launch device Kernel group_options.emplace_back(/*num_threads*/ 1, /*allow_spinning*/ true, - /*track_task*/ true, + /*track_task*/ false, /*detached*/ true, /*events_waiter*/ waiter); queue_group_ = CreateWorkQueueGroup(group_options); -- GitLab