From e4a8815d11b0042a5faade179239d64f8937c72a Mon Sep 17 00:00:00 2001 From: liutiexing <74819124+liutiexing@users.noreply.github.com> Date: Thu, 26 Aug 2021 11:16:00 +0800 Subject: [PATCH] add temporary MultiThreadedWorkQueue (#35158) --- .../fluid/framework/new_executor/workqueue.cc | 29 ++++++++++++++++++ .../framework/new_executor/workqueue_test.cc | 30 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/paddle/fluid/framework/new_executor/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue.cc index 6586b418ba2..3fcc2fa1014 100644 --- a/paddle/fluid/framework/new_executor/workqueue.cc +++ b/paddle/fluid/framework/new_executor/workqueue.cc @@ -37,5 +37,34 @@ std::unique_ptr CreateSingleThreadedWorkQueue() { return std::move(ptr); } +class MultiThreadedWorkQueue : public WorkQueue { + public: + explicit MultiThreadedWorkQueue(int num_threads) : queue_(num_threads) { + assert(num_threads > 1); + } + + MultiThreadedWorkQueue(const MultiThreadedWorkQueue&) = delete; + + MultiThreadedWorkQueue& operator=(const MultiThreadedWorkQueue&) = delete; + + virtual ~MultiThreadedWorkQueue() = default; + + void AddTask(std::function fn) override { + queue_.AddTask(std::move(fn)); + } + + void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); } + + size_t NumThreads() override { return queue_.NumThreads(); } + + private: + NonblockingThreadPool queue_; +}; + +std::unique_ptr CreateMultiThreadedWorkQueue(int num_threads) { + std::unique_ptr ptr(new MultiThreadedWorkQueue(num_threads)); + return std::move(ptr); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue_test.cc index f5d0a7f4033..691c78f3df2 100644 --- a/paddle/fluid/framework/new_executor/workqueue_test.cc +++ b/paddle/fluid/framework/new_executor/workqueue_test.cc @@ -43,3 +43,33 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { EXPECT_EQ(finished.load(), true); EXPECT_EQ(counter.load(), kLoopNum); } + +TEST(WorkQueue, TestMultiThreadedWorkQueue) { + VLOG(1) << "In Test"; + using paddle::framework::WorkQueue; + using paddle::framework::CreateMultiThreadedWorkQueue; + std::atomic finished{false}; + std::atomic counter{0}; + constexpr unsigned kExternalLoopNum = 100; + constexpr unsigned kLoopNum = 1000000; + // CreateSingleThreadedWorkQueue + std::unique_ptr work_queue = CreateMultiThreadedWorkQueue(10); + // NumThreads + EXPECT_EQ(work_queue->NumThreads(), 10u); + // AddTask + EXPECT_EQ(finished.load(), false); + EXPECT_EQ(counter.load(), 0u); + for (unsigned i = 0; i < kExternalLoopNum; ++i) { + work_queue->AddTask([&counter, &finished, kLoopNum]() { + for (unsigned i = 0; i < kLoopNum; ++i) { + ++counter; + } + finished = true; + }); + } + // WaitQueueEmpty + EXPECT_EQ(finished.load(), false); + work_queue->WaitQueueEmpty(); + EXPECT_EQ(finished.load(), true); + EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); +} -- GitLab