From fe6dbdd38b838a6b4d116c7523bc18990b835aee Mon Sep 17 00:00:00 2001 From: liutiexing <74819124+liutiexing@users.noreply.github.com> Date: Tue, 26 Oct 2021 14:38:10 +0800 Subject: [PATCH] [new-exec] Add cancel for thread pool (#36688) * add align for WorkQueue * add spinlock * merge develop * merge * Add EventsWaiter * update * update * update Error MSG * update EventsWaiter * Add Cancel For ThreadPool * Add UT for Cancel --- .../new_executor/interpretercore_util.h | 2 ++ .../new_executor/nonblocking_threadpool.h | 6 ++++++ .../framework/new_executor/thread_environment.h | 11 ++++++++++- paddle/fluid/framework/new_executor/workqueue.cc | 16 ++++++++++++++++ paddle/fluid/framework/new_executor/workqueue.h | 4 ++++ .../framework/new_executor/workqueue_test.cc | 6 +++++- 6 files changed, 43 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index 3c927a8d81..b1e1c02ab9 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -81,6 +81,8 @@ class AsyncWorkQueue { queue_group_->AddTask(static_cast(op_func_type), std::move(fn)); } + void Cancel() { queue_group_->Cancel(); } + AtomicVectorSizeT& AtomicDeps() { return atomic_deps_; } AtomicVectorSizeT& AtomicVarRef() { return atomic_var_ref_; } diff --git a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h index 667723c671..6e56532456 100644 --- a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h +++ b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h @@ -173,6 +173,12 @@ class ThreadPoolTempl { ec_.Notify(true); } + void WaitThreadsExit() { + for (size_t i = 0; i < thread_data_.size(); ++i) { + thread_data_[i].thread->WaitExit(); + } + } + size_t NumThreads() const { return num_threads_; } int CurrentThreadId() const { diff --git a/paddle/fluid/framework/new_executor/thread_environment.h b/paddle/fluid/framework/new_executor/thread_environment.h index be93627418..eb1ee4de90 100644 --- a/paddle/fluid/framework/new_executor/thread_environment.h +++ b/paddle/fluid/framework/new_executor/thread_environment.h @@ -25,7 +25,16 @@ struct StlThreadEnvironment { class EnvThread { public: explicit EnvThread(std::function f) : thr_(std::move(f)) {} - ~EnvThread() { thr_.join(); } + void WaitExit() { + if (thr_.joinable()) { + thr_.join(); + } + } + ~EnvThread() { + if (thr_.joinable()) { + thr_.join(); + } + } private: std::thread thr_; diff --git a/paddle/fluid/framework/new_executor/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue.cc index 559c7a2f13..7607b3a297 100644 --- a/paddle/fluid/framework/new_executor/workqueue.cc +++ b/paddle/fluid/framework/new_executor/workqueue.cc @@ -49,6 +49,11 @@ class WorkQueueImpl : public WorkQueue { queue_->AddTask(std::move(fn)); } + void Cancel() override { + queue_->Cancel(); + queue_->WaitThreadsExit(); + } + size_t NumThreads() const override { return queue_->NumThreads(); } private: @@ -69,6 +74,8 @@ class WorkQueueGroupImpl : public WorkQueueGroup { size_t QueueGroupNumThreads() const override; + void Cancel() override; + private: std::vector queues_; NonblockingThreadPool* queues_storage_; @@ -136,6 +143,15 @@ size_t WorkQueueGroupImpl::QueueGroupNumThreads() const { return total_num; } +void WorkQueueGroupImpl::Cancel() { + for (auto queue : queues_) { + queue->Cancel(); + } + for (auto queue : queues_) { + queue->WaitThreadsExit(); + } +} + } // namespace std::unique_ptr CreateSingleThreadedWorkQueue( diff --git a/paddle/fluid/framework/new_executor/workqueue.h b/paddle/fluid/framework/new_executor/workqueue.h index e49ce9df80..3520307c70 100644 --- a/paddle/fluid/framework/new_executor/workqueue.h +++ b/paddle/fluid/framework/new_executor/workqueue.h @@ -64,6 +64,8 @@ class WorkQueue { virtual size_t NumThreads() const = 0; + virtual void Cancel() = 0; + protected: WorkQueueOptions options_; }; @@ -88,6 +90,8 @@ class WorkQueueGroup { virtual size_t QueueGroupNumThreads() const = 0; + virtual void Cancel() = 0; + protected: std::vector queues_options_; }; diff --git a/paddle/fluid/framework/new_executor/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue_test.cc index c10c4172cd..3ea0096b63 100644 --- a/paddle/fluid/framework/new_executor/workqueue_test.cc +++ b/paddle/fluid/framework/new_executor/workqueue_test.cc @@ -83,6 +83,8 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { events_waiter.WaitEvent(); EXPECT_EQ(finished.load(), true); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); + // Cancel + work_queue->Cancel(); } TEST(WorkQueue, TestWorkQueueGroup) { @@ -119,7 +121,9 @@ TEST(WorkQueue, TestWorkQueueGroup) { ++counter; } }); - // WaitQueueGroupEmpty() + // WaitQueueGroupEmpty events_waiter.WaitEvent(); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); + // Cancel + queue_group->Cancel(); } -- GitLab