diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index dcc6bb210ceda670ae08694803c31dd44dc90316..ea3e7dd411929140ea7b6b4c6c9ec32d89e52ab0 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -39,9 +39,11 @@ InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block, : place_(place), block_(block), global_scope_(global_scope), - stream_analyzer_(place), - async_work_queue_(kHostNumThreads, &main_thread_blocker_) { + stream_analyzer_(place) { is_build_ = false; + async_work_queue_.reset( + new interpreter::AsyncWorkQueue(kHostNumThreads, &main_thread_blocker_)); + gc_.reset(new InterpreterCoreGarbageCollector()); feed_names_ = feed_names; @@ -55,6 +57,13 @@ InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block, // convert to run graph } +InterpreterCore::~InterpreterCore() { + // cancle gc's thread + gc_.reset(nullptr); + + async_work_queue_.reset(nullptr); +} + paddle::framework::FetchList InterpreterCore::Run( const std::vector& feed_tensors) { auto FeedInput = [&] { @@ -349,16 +358,16 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { void InterpreterCore::ExecuteInstructionList( const std::vector& vec_instr) { - async_work_queue_.PrepareAtomicDeps(dependecy_count_); - async_work_queue_.PrepareAtomicVarRef(vec_meta_info_); + async_work_queue_->PrepareAtomicDeps(dependecy_count_); + async_work_queue_->PrepareAtomicVarRef(vec_meta_info_); op_run_number_ = 0; exception_holder_.Clear(); for (size_t i = 0; i < dependecy_count_.size(); ++i) { if (dependecy_count_[i] == 0) { - async_work_queue_.AddTask(vec_instr.at(i).KernelType(), - [&, i] { RunInstructionAsync(i); }); + async_work_queue_->AddTask(vec_instr.at(i).KernelType(), + [&, i] { RunInstructionAsync(i); }); } } @@ -380,7 +389,7 @@ void InterpreterCore::ExecuteInstructionList( void InterpreterCore::RunNextInstructions( const Instruction& instr, std::queue* reserved_next_ops) { auto& next_instr = instr.NextInstructions(); - auto& atomic_deps = async_work_queue_.AtomicDeps(); + auto& atomic_deps = async_work_queue_->AtomicDeps(); auto IsReady = [&](size_t next_id) { return atomic_deps[next_id]->fetch_sub(1, std::memory_order_relaxed) == 1; }; @@ -389,7 +398,7 @@ void InterpreterCore::RunNextInstructions( // move all sync_ops into other threads for (auto next_id : next_instr.SyncRunIds()) { if (IsReady(next_id)) { - async_work_queue_.AddTask( + async_work_queue_->AddTask( vec_instruction_[next_id].KernelType(), [&, next_id] { RunInstructionAsync(next_id); }); } @@ -409,7 +418,7 @@ void InterpreterCore::RunNextInstructions( // move async_ops into async_thread for (auto next_id : next_instr.EventRunIds()) { if (IsReady(next_id)) { - async_work_queue_.AddTask( + async_work_queue_->AddTask( vec_instruction_[next_id].KernelType(), [&, next_id] { RunInstructionAsync(next_id); }); } @@ -425,7 +434,7 @@ void InterpreterCore::RunNextInstructions( continue; } // move rest ops into other threads - async_work_queue_.AddTask( + async_work_queue_->AddTask( vec_instruction_[next_id].KernelType(), [&, next_id] { RunInstructionAsync(next_id); }); } @@ -483,7 +492,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { void InterpreterCore::CheckGC(const Instruction& instr) { size_t instr_id = instr.Id(); auto& var_scope = *global_scope_; - auto& atomic_var_ref = async_work_queue_.AtomicVarRef(); + auto& atomic_var_ref = async_work_queue_->AtomicVarRef(); for (auto var_id : instr.GCCheckVars()) { bool is_ready = @@ -493,8 +502,8 @@ void InterpreterCore::CheckGC(const Instruction& instr) { continue; } if (is_ready) { - gc_.Add(var_scope.Var(var_id), gc_event_.at(instr_id), - &instr.DeviceContext()); + gc_->Add(var_scope.Var(var_id), gc_event_.at(instr_id), + &instr.DeviceContext()); } } } diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index e6f623cb72831b61f24d7646a00925381ea28e89..3a6876a91285721ea95b80e69994f629a15a86a5 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -44,6 +44,8 @@ class InterpreterCore { VariableScope* global_scope, const std::vector& feed_names); + ~InterpreterCore(); + paddle::framework::FetchList Run( const std::vector& feed_tensors); @@ -94,11 +96,11 @@ class InterpreterCore { StreamAnalyzer stream_analyzer_; EventManager event_manager_; EventsWaiter main_thread_blocker_; - interpreter::AsyncWorkQueue async_work_queue_; + std::unique_ptr async_work_queue_; details::ExceptionHolder exception_holder_; std::shared_ptr exception_notifier_{nullptr}; - InterpreterCoreGarbageCollector gc_; + std::unique_ptr gc_; std::vector gc_event_; std::atomic op_run_number_{0}; }; diff --git a/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.cc b/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.cc index 3025e23445e1cae72f3a7829a915f180c4a7a506..59dd44ab9ada6b3bc9c23c7c550e5fc24b69c983 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.cc @@ -28,6 +28,10 @@ InterpreterCoreGarbageCollector::InterpreterCoreGarbageCollector() { queue_ = CreateSingleThreadedWorkQueue(options); } +InterpreterCoreGarbageCollector::~InterpreterCoreGarbageCollector() { + queue_.reset(nullptr); +} + void InterpreterCoreGarbageCollector::Add( std::shared_ptr garbage, paddle::platform::DeviceEvent& event, const platform::DeviceContext* ctx) { diff --git a/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h b/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h index b1157c861754ce913b66683e1ad52a08e7550a57..166139a73c8f94cfe0cfc323764a85fa905068b0 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h +++ b/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h @@ -35,6 +35,8 @@ class InterpreterCoreGarbageCollector { public: InterpreterCoreGarbageCollector(); + ~InterpreterCoreGarbageCollector(); + void Add(std::shared_ptr garbage, // NOLINT paddle::platform::DeviceEvent& event, // NOLINT const platform::DeviceContext* ctx);