未验证 提交 404a4a6a 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] add completion_notifier (#38447)

* add completion_notifier

* fix bug

* unregister event waiter
上级 1db61c3e
...@@ -30,6 +30,7 @@ DECLARE_bool(check_nan_inf); ...@@ -30,6 +30,7 @@ DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
constexpr const char* kExceptionCaught = "ExceptionCaught"; constexpr const char* kExceptionCaught = "ExceptionCaught";
constexpr const char* kTaskCompletion = "TaskCompletion";
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -49,6 +50,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -49,6 +50,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
gc_.reset(new InterpreterCoreGarbageCollector()); gc_.reset(new InterpreterCoreGarbageCollector());
exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught); exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
create_local_scope_ = FLAGS_new_executor_use_local_scope; create_local_scope_ = FLAGS_new_executor_use_local_scope;
if (FLAGS_new_executor_use_local_scope) { if (FLAGS_new_executor_use_local_scope) {
...@@ -69,6 +71,9 @@ InterpreterCore::~InterpreterCore() { ...@@ -69,6 +71,9 @@ InterpreterCore::~InterpreterCore() {
// cancel gc's thread // cancel gc's thread
gc_.reset(nullptr); gc_.reset(nullptr);
exception_notifier_->UnregisterEvent();
completion_notifier_->UnregisterEvent();
async_work_queue_.reset(nullptr); async_work_queue_.reset(nullptr);
} }
...@@ -417,7 +422,7 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -417,7 +422,7 @@ void InterpreterCore::ExecuteInstructionList(
const std::vector<Instruction>& vec_instr) { const std::vector<Instruction>& vec_instr) {
async_work_queue_->PrepareAtomicDeps(dependecy_count_); async_work_queue_->PrepareAtomicDeps(dependecy_count_);
async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo()); async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
op_run_number_ = 0; unfinished_op_numer_ = vec_instr.size();
exception_holder_.Clear(); exception_holder_.Clear();
...@@ -436,12 +441,6 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -436,12 +441,6 @@ void InterpreterCore::ExecuteInstructionList(
async_work_queue_->Cancel(); async_work_queue_->Cancel();
exception_holder_.ReThrow(); exception_holder_.ReThrow();
} }
PADDLE_ENFORCE_EQ(
op_run_number_.load(), vec_instr.size(),
platform::errors::Fatal(
"Required op_run_number == %d, but received op_run_number = %d.",
vec_instr.size(), op_run_number_.load()));
} }
void InterpreterCore::RunNextInstructions( void InterpreterCore::RunNextInstructions(
...@@ -539,8 +538,15 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { ...@@ -539,8 +538,15 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
return; return;
} }
VLOG(4) << "unfinished_op_numer_: " << unfinished_op_numer_;
if (UNLIKELY(unfinished_op_numer_.fetch_sub(1, std::memory_order_relaxed) ==
1)) {
if (completion_notifier_ != nullptr) {
completion_notifier_->NotifyEvent();
}
}
interpreter::RecordEvent(instr_node, place_); interpreter::RecordEvent(instr_node, place_);
op_run_number_.fetch_add(1, std::memory_order_relaxed);
RunNextInstructions(instr_node, &ready_ops); RunNextInstructions(instr_node, &ready_ops);
} }
......
...@@ -101,7 +101,7 @@ class InterpreterCore { ...@@ -101,7 +101,7 @@ class InterpreterCore {
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
std::vector<size_t> dependecy_count_; std::vector<size_t> dependecy_count_;
std::atomic<size_t> op_run_number_{0}; std::atomic<size_t> unfinished_op_numer_{0};
std::vector<std::vector<size_t>> input_var2op_info_; std::vector<std::vector<size_t>> input_var2op_info_;
StreamAnalyzer stream_analyzer_; StreamAnalyzer stream_analyzer_;
...@@ -109,6 +109,7 @@ class InterpreterCore { ...@@ -109,6 +109,7 @@ class InterpreterCore {
std::unique_ptr<interpreter::AsyncWorkQueue> async_work_queue_; std::unique_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
details::ExceptionHolder exception_holder_; details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr}; std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr};
std::unique_ptr<InterpreterCoreGarbageCollector> gc_; std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_; std::vector<paddle::platform::DeviceEvent> gc_event_;
......
...@@ -60,13 +60,13 @@ class AsyncWorkQueue { ...@@ -60,13 +60,13 @@ class AsyncWorkQueue {
// for execute host Kernel // for execute host Kernel
group_options.emplace_back(/*num_threads*/ host_num_threads, group_options.emplace_back(/*num_threads*/ host_num_threads,
/*allow_spinning*/ true, /*allow_spinning*/ true,
/*track_task*/ true, /*track_task*/ false,
/*detached*/ true, /*detached*/ true,
/*events_waiter*/ waiter); /*events_waiter*/ waiter);
// for launch device Kernel // for launch device Kernel
group_options.emplace_back(/*num_threads*/ 1, group_options.emplace_back(/*num_threads*/ 1,
/*allow_spinning*/ true, /*allow_spinning*/ true,
/*track_task*/ true, /*track_task*/ false,
/*detached*/ true, /*detached*/ true,
/*events_waiter*/ waiter); /*events_waiter*/ waiter);
queue_group_ = CreateWorkQueueGroup(group_options); queue_group_ = CreateWorkQueueGroup(group_options);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册