未验证 提交 404a4a6a 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] add completion_notifier (#38447)

* add completion_notifier

* fix bug

* unregister event waiter
上级 1db61c3e
......@@ -30,6 +30,7 @@ DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark);
constexpr const char* kExceptionCaught = "ExceptionCaught";
constexpr const char* kTaskCompletion = "TaskCompletion";
namespace paddle {
namespace framework {
......@@ -49,6 +50,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
gc_.reset(new InterpreterCoreGarbageCollector());
exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
create_local_scope_ = FLAGS_new_executor_use_local_scope;
if (FLAGS_new_executor_use_local_scope) {
......@@ -69,6 +71,9 @@ InterpreterCore::~InterpreterCore() {
// cancle gc's thread
gc_.reset(nullptr);
exception_notifier_->UnregisterEvent();
completion_notifier_->UnregisterEvent();
async_work_queue_.reset(nullptr);
}
......@@ -417,7 +422,7 @@ void InterpreterCore::ExecuteInstructionList(
const std::vector<Instruction>& vec_instr) {
async_work_queue_->PrepareAtomicDeps(dependecy_count_);
async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
op_run_number_ = 0;
unfinished_op_numer_ = vec_instr.size();
exception_holder_.Clear();
......@@ -436,12 +441,6 @@ void InterpreterCore::ExecuteInstructionList(
async_work_queue_->Cancel();
exception_holder_.ReThrow();
}
PADDLE_ENFORCE_EQ(
op_run_number_.load(), vec_instr.size(),
platform::errors::Fatal(
"Required op_run_number == %d, but received op_run_number = %d.",
vec_instr.size(), op_run_number_.load()));
}
void InterpreterCore::RunNextInstructions(
......@@ -539,8 +538,15 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
return;
}
VLOG(4) << "unfinished_op_numer_: " << unfinished_op_numer_;
if (UNLIKELY(unfinished_op_numer_.fetch_sub(1, std::memory_order_relaxed) ==
1)) {
if (completion_notifier_ != nullptr) {
completion_notifier_->NotifyEvent();
}
}
interpreter::RecordEvent(instr_node, place_);
op_run_number_.fetch_add(1, std::memory_order_relaxed);
RunNextInstructions(instr_node, &ready_ops);
}
......
......@@ -101,7 +101,7 @@ class InterpreterCore {
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
std::vector<size_t> dependecy_count_;
std::atomic<size_t> op_run_number_{0};
std::atomic<size_t> unfinished_op_numer_{0};
std::vector<std::vector<size_t>> input_var2op_info_;
StreamAnalyzer stream_analyzer_;
......@@ -109,6 +109,7 @@ class InterpreterCore {
std::unique_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr};
std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_;
......
......@@ -60,13 +60,13 @@ class AsyncWorkQueue {
// for execute host Kernel
group_options.emplace_back(/*num_threads*/ host_num_threads,
/*allow_spinning*/ true,
/*track_task*/ true,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
// for launch device Kernel
group_options.emplace_back(/*num_threads*/ 1,
/*allow_spinning*/ true,
/*track_task*/ true,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
queue_group_ = CreateWorkQueueGroup(group_options);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册