diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index 3c927a8d81d16325c1c82709bfed2c07511e9ee3..b1e1c02ab9513b79ae34a19b8f2d6907380716ce 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -81,6 +81,8 @@ class AsyncWorkQueue { queue_group_->AddTask(static_cast(op_func_type), std::move(fn)); } + void Cancel() { queue_group_->Cancel(); } + AtomicVectorSizeT& AtomicDeps() { return atomic_deps_; } AtomicVectorSizeT& AtomicVarRef() { return atomic_var_ref_; } diff --git a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h index 667723c67165cca020a349f0b7b9e98604c686fc..6e56532456c6fd6b5cfbb4a14601e1e335495e73 100644 --- a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h +++ b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h @@ -173,6 +173,12 @@ class ThreadPoolTempl { ec_.Notify(true); } + void WaitThreadsExit() { + for (size_t i = 0; i < thread_data_.size(); ++i) { + thread_data_[i].thread->WaitExit(); + } + } + size_t NumThreads() const { return num_threads_; } int CurrentThreadId() const { diff --git a/paddle/fluid/framework/new_executor/thread_environment.h b/paddle/fluid/framework/new_executor/thread_environment.h index be936274186f4fb2ed2f6cb47ae6d1c99f2271c1..eb1ee4de90898d851398293246f24485fc846b53 100644 --- a/paddle/fluid/framework/new_executor/thread_environment.h +++ b/paddle/fluid/framework/new_executor/thread_environment.h @@ -25,7 +25,16 @@ struct StlThreadEnvironment { class EnvThread { public: explicit EnvThread(std::function f) : thr_(std::move(f)) {} - ~EnvThread() { thr_.join(); } + void WaitExit() { + if (thr_.joinable()) { + thr_.join(); + } + } + ~EnvThread() { + if (thr_.joinable()) { + thr_.join(); + } + } private: std::thread thr_; diff --git a/paddle/fluid/framework/new_executor/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue.cc index 559c7a2f13785f05c1173c2306a591dd66a6332f..7607b3a297f843e668b6f8af6d1da7f4308ede37 100644 --- a/paddle/fluid/framework/new_executor/workqueue.cc +++ b/paddle/fluid/framework/new_executor/workqueue.cc @@ -49,6 +49,11 @@ class WorkQueueImpl : public WorkQueue { queue_->AddTask(std::move(fn)); } + void Cancel() override { + queue_->Cancel(); + queue_->WaitThreadsExit(); + } + size_t NumThreads() const override { return queue_->NumThreads(); } private: @@ -69,6 +74,8 @@ class WorkQueueGroupImpl : public WorkQueueGroup { size_t QueueGroupNumThreads() const override; + void Cancel() override; + private: std::vector queues_; NonblockingThreadPool* queues_storage_; @@ -136,6 +143,15 @@ size_t WorkQueueGroupImpl::QueueGroupNumThreads() const { return total_num; } +void WorkQueueGroupImpl::Cancel() { + for (auto queue : queues_) { + queue->Cancel(); + } + for (auto queue : queues_) { + queue->WaitThreadsExit(); + } +} + } // namespace std::unique_ptr CreateSingleThreadedWorkQueue( diff --git a/paddle/fluid/framework/new_executor/workqueue.h b/paddle/fluid/framework/new_executor/workqueue.h index e49ce9df8054ad2f8d80ca120ce35d80fc75a82a..3520307c70b8e421d20a76c659917b2b91bb71df 100644 --- a/paddle/fluid/framework/new_executor/workqueue.h +++ b/paddle/fluid/framework/new_executor/workqueue.h @@ -64,6 +64,8 @@ class WorkQueue { virtual size_t NumThreads() const = 0; + virtual void Cancel() = 0; + protected: WorkQueueOptions options_; }; @@ -88,6 +90,8 @@ class WorkQueueGroup { virtual size_t QueueGroupNumThreads() const = 0; + virtual void Cancel() = 0; + protected: std::vector queues_options_; }; diff --git a/paddle/fluid/framework/new_executor/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue_test.cc index c10c4172cd5cd64cfac054cf4195fc9b1beec434..3ea0096b631e8208856530b4690954f0689cff18 100644 --- a/paddle/fluid/framework/new_executor/workqueue_test.cc +++ b/paddle/fluid/framework/new_executor/workqueue_test.cc @@ -83,6 +83,8 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { events_waiter.WaitEvent(); EXPECT_EQ(finished.load(), true); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); + // Cancel + work_queue->Cancel(); } TEST(WorkQueue, TestWorkQueueGroup) { @@ -119,7 +121,9 @@ TEST(WorkQueue, TestWorkQueueGroup) { ++counter; } }); - // WaitQueueGroupEmpty() + // WaitQueueGroupEmpty events_waiter.WaitEvent(); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); + // Cancel + queue_group->Cancel(); }