未验证 提交 fe6dbdd3 编写于 作者: L liutiexing 提交者: GitHub

[new-exec] Add cancel for thread pool (#36688)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* update

* update

* update Error MSG

* update EventsWaiter

* Add Cancel For ThreadPool

* Add UT for Cancel
上级 eb9ef885
...@@ -81,6 +81,8 @@ class AsyncWorkQueue { ...@@ -81,6 +81,8 @@ class AsyncWorkQueue {
queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn)); queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
} }
void Cancel() { queue_group_->Cancel(); }
AtomicVectorSizeT& AtomicDeps() { return atomic_deps_; } AtomicVectorSizeT& AtomicDeps() { return atomic_deps_; }
AtomicVectorSizeT& AtomicVarRef() { return atomic_var_ref_; } AtomicVectorSizeT& AtomicVarRef() { return atomic_var_ref_; }
......
...@@ -173,6 +173,12 @@ class ThreadPoolTempl { ...@@ -173,6 +173,12 @@ class ThreadPoolTempl {
ec_.Notify(true); ec_.Notify(true);
} }
void WaitThreadsExit() {
for (size_t i = 0; i < thread_data_.size(); ++i) {
thread_data_[i].thread->WaitExit();
}
}
size_t NumThreads() const { return num_threads_; } size_t NumThreads() const { return num_threads_; }
int CurrentThreadId() const { int CurrentThreadId() const {
......
...@@ -25,7 +25,16 @@ struct StlThreadEnvironment { ...@@ -25,7 +25,16 @@ struct StlThreadEnvironment {
class EnvThread { class EnvThread {
public: public:
explicit EnvThread(std::function<void()> f) : thr_(std::move(f)) {} explicit EnvThread(std::function<void()> f) : thr_(std::move(f)) {}
~EnvThread() { thr_.join(); } void WaitExit() {
if (thr_.joinable()) {
thr_.join();
}
}
~EnvThread() {
if (thr_.joinable()) {
thr_.join();
}
}
private: private:
std::thread thr_; std::thread thr_;
......
...@@ -49,6 +49,11 @@ class WorkQueueImpl : public WorkQueue { ...@@ -49,6 +49,11 @@ class WorkQueueImpl : public WorkQueue {
queue_->AddTask(std::move(fn)); queue_->AddTask(std::move(fn));
} }
void Cancel() override {
queue_->Cancel();
queue_->WaitThreadsExit();
}
size_t NumThreads() const override { return queue_->NumThreads(); } size_t NumThreads() const override { return queue_->NumThreads(); }
private: private:
...@@ -69,6 +74,8 @@ class WorkQueueGroupImpl : public WorkQueueGroup { ...@@ -69,6 +74,8 @@ class WorkQueueGroupImpl : public WorkQueueGroup {
size_t QueueGroupNumThreads() const override; size_t QueueGroupNumThreads() const override;
void Cancel() override;
private: private:
std::vector<NonblockingThreadPool*> queues_; std::vector<NonblockingThreadPool*> queues_;
NonblockingThreadPool* queues_storage_; NonblockingThreadPool* queues_storage_;
...@@ -136,6 +143,15 @@ size_t WorkQueueGroupImpl::QueueGroupNumThreads() const { ...@@ -136,6 +143,15 @@ size_t WorkQueueGroupImpl::QueueGroupNumThreads() const {
return total_num; return total_num;
} }
void WorkQueueGroupImpl::Cancel() {
for (auto queue : queues_) {
queue->Cancel();
}
for (auto queue : queues_) {
queue->WaitThreadsExit();
}
}
} // namespace } // namespace
std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue( std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue(
......
...@@ -64,6 +64,8 @@ class WorkQueue { ...@@ -64,6 +64,8 @@ class WorkQueue {
virtual size_t NumThreads() const = 0; virtual size_t NumThreads() const = 0;
virtual void Cancel() = 0;
protected: protected:
WorkQueueOptions options_; WorkQueueOptions options_;
}; };
...@@ -88,6 +90,8 @@ class WorkQueueGroup { ...@@ -88,6 +90,8 @@ class WorkQueueGroup {
virtual size_t QueueGroupNumThreads() const = 0; virtual size_t QueueGroupNumThreads() const = 0;
virtual void Cancel() = 0;
protected: protected:
std::vector<WorkQueueOptions> queues_options_; std::vector<WorkQueueOptions> queues_options_;
}; };
......
...@@ -83,6 +83,8 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { ...@@ -83,6 +83,8 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
events_waiter.WaitEvent(); events_waiter.WaitEvent();
EXPECT_EQ(finished.load(), true); EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum);
// Cancel
work_queue->Cancel();
} }
TEST(WorkQueue, TestWorkQueueGroup) { TEST(WorkQueue, TestWorkQueueGroup) {
...@@ -119,7 +121,9 @@ TEST(WorkQueue, TestWorkQueueGroup) { ...@@ -119,7 +121,9 @@ TEST(WorkQueue, TestWorkQueueGroup) {
++counter; ++counter;
} }
}); });
// WaitQueueGroupEmpty() // WaitQueueGroupEmpty
events_waiter.WaitEvent(); events_waiter.WaitEvent();
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
// Cancel
queue_group->Cancel();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册