未验证 提交 f0c77378 编写于 作者: W wanghuancoder 提交者: GitHub

cancel threadpool before destructing InterpreterCore (#37034)

* cancel thread when exit, test=develop

* gc to unique_ptr, test=develop

* refine, test=develop

* fix namespace, test=develop
上级 c9763006
...@@ -39,9 +39,11 @@ InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block, ...@@ -39,9 +39,11 @@ InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block,
: place_(place), : place_(place),
block_(block), block_(block),
global_scope_(global_scope), global_scope_(global_scope),
stream_analyzer_(place), stream_analyzer_(place) {
async_work_queue_(kHostNumThreads, &main_thread_blocker_) {
is_build_ = false; is_build_ = false;
async_work_queue_.reset(
new interpreter::AsyncWorkQueue(kHostNumThreads, &main_thread_blocker_));
gc_.reset(new InterpreterCoreGarbageCollector());
feed_names_ = feed_names; feed_names_ = feed_names;
...@@ -55,6 +57,13 @@ InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block, ...@@ -55,6 +57,13 @@ InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block,
// convert to run graph // convert to run graph
} }
InterpreterCore::~InterpreterCore() {
  // Tear down the GC first: resetting the unique_ptr cancels/joins the GC's
  // worker thread (see ~InterpreterCoreGarbageCollector) before the async
  // work queue it may post to is destroyed.
  gc_.reset();
  // Then shut down the async work queue so no task outlives this object.
  async_work_queue_.reset();
}
paddle::framework::FetchList InterpreterCore::Run( paddle::framework::FetchList InterpreterCore::Run(
const std::vector<framework::LoDTensor>& feed_tensors) { const std::vector<framework::LoDTensor>& feed_tensors) {
auto FeedInput = [&] { auto FeedInput = [&] {
...@@ -349,16 +358,16 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -349,16 +358,16 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
void InterpreterCore::ExecuteInstructionList( void InterpreterCore::ExecuteInstructionList(
const std::vector<Instruction>& vec_instr) { const std::vector<Instruction>& vec_instr) {
async_work_queue_.PrepareAtomicDeps(dependecy_count_); async_work_queue_->PrepareAtomicDeps(dependecy_count_);
async_work_queue_.PrepareAtomicVarRef(vec_meta_info_); async_work_queue_->PrepareAtomicVarRef(vec_meta_info_);
op_run_number_ = 0; op_run_number_ = 0;
exception_holder_.Clear(); exception_holder_.Clear();
for (size_t i = 0; i < dependecy_count_.size(); ++i) { for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) { if (dependecy_count_[i] == 0) {
async_work_queue_.AddTask(vec_instr.at(i).KernelType(), async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
[&, i] { RunInstructionAsync(i); }); [&, i] { RunInstructionAsync(i); });
} }
} }
...@@ -380,7 +389,7 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -380,7 +389,7 @@ void InterpreterCore::ExecuteInstructionList(
void InterpreterCore::RunNextInstructions( void InterpreterCore::RunNextInstructions(
const Instruction& instr, std::queue<size_t>* reserved_next_ops) { const Instruction& instr, std::queue<size_t>* reserved_next_ops) {
auto& next_instr = instr.NextInstructions(); auto& next_instr = instr.NextInstructions();
auto& atomic_deps = async_work_queue_.AtomicDeps(); auto& atomic_deps = async_work_queue_->AtomicDeps();
auto IsReady = [&](size_t next_id) { auto IsReady = [&](size_t next_id) {
return atomic_deps[next_id]->fetch_sub(1, std::memory_order_relaxed) == 1; return atomic_deps[next_id]->fetch_sub(1, std::memory_order_relaxed) == 1;
}; };
...@@ -389,7 +398,7 @@ void InterpreterCore::RunNextInstructions( ...@@ -389,7 +398,7 @@ void InterpreterCore::RunNextInstructions(
// move all sync_ops into other threads // move all sync_ops into other threads
for (auto next_id : next_instr.SyncRunIds()) { for (auto next_id : next_instr.SyncRunIds()) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
async_work_queue_.AddTask( async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(), vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); }); [&, next_id] { RunInstructionAsync(next_id); });
} }
...@@ -409,7 +418,7 @@ void InterpreterCore::RunNextInstructions( ...@@ -409,7 +418,7 @@ void InterpreterCore::RunNextInstructions(
// move async_ops into async_thread // move async_ops into async_thread
for (auto next_id : next_instr.EventRunIds()) { for (auto next_id : next_instr.EventRunIds()) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
async_work_queue_.AddTask( async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(), vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); }); [&, next_id] { RunInstructionAsync(next_id); });
} }
...@@ -425,7 +434,7 @@ void InterpreterCore::RunNextInstructions( ...@@ -425,7 +434,7 @@ void InterpreterCore::RunNextInstructions(
continue; continue;
} }
// move rest ops into other threads // move rest ops into other threads
async_work_queue_.AddTask( async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(), vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); }); [&, next_id] { RunInstructionAsync(next_id); });
} }
...@@ -483,7 +492,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { ...@@ -483,7 +492,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
void InterpreterCore::CheckGC(const Instruction& instr) { void InterpreterCore::CheckGC(const Instruction& instr) {
size_t instr_id = instr.Id(); size_t instr_id = instr.Id();
auto& var_scope = *global_scope_; auto& var_scope = *global_scope_;
auto& atomic_var_ref = async_work_queue_.AtomicVarRef(); auto& atomic_var_ref = async_work_queue_->AtomicVarRef();
for (auto var_id : instr.GCCheckVars()) { for (auto var_id : instr.GCCheckVars()) {
bool is_ready = bool is_ready =
...@@ -493,8 +502,8 @@ void InterpreterCore::CheckGC(const Instruction& instr) { ...@@ -493,8 +502,8 @@ void InterpreterCore::CheckGC(const Instruction& instr) {
continue; continue;
} }
if (is_ready) { if (is_ready) {
gc_.Add(var_scope.Var(var_id), gc_event_.at(instr_id), gc_->Add(var_scope.Var(var_id), gc_event_.at(instr_id),
&instr.DeviceContext()); &instr.DeviceContext());
} }
} }
} }
......
...@@ -44,6 +44,8 @@ class InterpreterCore { ...@@ -44,6 +44,8 @@ class InterpreterCore {
VariableScope* global_scope, VariableScope* global_scope,
const std::vector<std::string>& feed_names); const std::vector<std::string>& feed_names);
~InterpreterCore();
paddle::framework::FetchList Run( paddle::framework::FetchList Run(
const std::vector<framework::LoDTensor>& feed_tensors); const std::vector<framework::LoDTensor>& feed_tensors);
...@@ -94,11 +96,11 @@ class InterpreterCore { ...@@ -94,11 +96,11 @@ class InterpreterCore {
StreamAnalyzer stream_analyzer_; StreamAnalyzer stream_analyzer_;
EventManager event_manager_; EventManager event_manager_;
EventsWaiter main_thread_blocker_; EventsWaiter main_thread_blocker_;
interpreter::AsyncWorkQueue async_work_queue_; std::unique_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
details::ExceptionHolder exception_holder_; details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr}; std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
InterpreterCoreGarbageCollector gc_; std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_; std::vector<paddle::platform::DeviceEvent> gc_event_;
std::atomic<size_t> op_run_number_{0}; std::atomic<size_t> op_run_number_{0};
}; };
......
...@@ -28,6 +28,10 @@ InterpreterCoreGarbageCollector::InterpreterCoreGarbageCollector() { ...@@ -28,6 +28,10 @@ InterpreterCoreGarbageCollector::InterpreterCoreGarbageCollector() {
queue_ = CreateSingleThreadedWorkQueue(options); queue_ = CreateSingleThreadedWorkQueue(options);
} }
InterpreterCoreGarbageCollector::~InterpreterCoreGarbageCollector() {
  // Destroy the single-threaded work queue explicitly so its worker thread
  // is stopped/joined before the remaining members are torn down.
  queue_.reset();
}
void InterpreterCoreGarbageCollector::Add( void InterpreterCoreGarbageCollector::Add(
std::shared_ptr<memory::Allocation> garbage, std::shared_ptr<memory::Allocation> garbage,
paddle::platform::DeviceEvent& event, const platform::DeviceContext* ctx) { paddle::platform::DeviceEvent& event, const platform::DeviceContext* ctx) {
......
...@@ -35,6 +35,8 @@ class InterpreterCoreGarbageCollector { ...@@ -35,6 +35,8 @@ class InterpreterCoreGarbageCollector {
public: public:
InterpreterCoreGarbageCollector(); InterpreterCoreGarbageCollector();
~InterpreterCoreGarbageCollector();
void Add(std::shared_ptr<memory::Allocation> garbage, // NOLINT void Add(std::shared_ptr<memory::Allocation> garbage, // NOLINT
paddle::platform::DeviceEvent& event, // NOLINT paddle::platform::DeviceEvent& event, // NOLINT
const platform::DeviceContext* ctx); const platform::DeviceContext* ctx);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册