diff --git a/paddle/fluid/framework/new_executor/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue.cc index 6586b418ba2781c5cdc46a5e0a221662e076557e..3fcc2fa1014e238a2532ac73f0ab95ceaafe6f24 100644 --- a/paddle/fluid/framework/new_executor/workqueue.cc +++ b/paddle/fluid/framework/new_executor/workqueue.cc @@ -37,5 +37,34 @@ std::unique_ptr CreateSingleThreadedWorkQueue() { return std::move(ptr); } +class MultiThreadedWorkQueue : public WorkQueue { + public: + explicit MultiThreadedWorkQueue(int num_threads) : queue_(num_threads) { + assert(num_threads > 1); + } + + MultiThreadedWorkQueue(const MultiThreadedWorkQueue&) = delete; + + MultiThreadedWorkQueue& operator=(const MultiThreadedWorkQueue&) = delete; + + virtual ~MultiThreadedWorkQueue() = default; + + void AddTask(std::function fn) override { + queue_.AddTask(std::move(fn)); + } + + void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); } + + size_t NumThreads() override { return queue_.NumThreads(); } + + private: + NonblockingThreadPool queue_; +}; + +std::unique_ptr CreateMultiThreadedWorkQueue(int num_threads) { + std::unique_ptr ptr(new MultiThreadedWorkQueue(num_threads)); + return std::move(ptr); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue_test.cc index f5d0a7f4033ad5b39311e15fb63d76614240dc76..691c78f3df2a3c385799ab18638dff66f4bc6103 100644 --- a/paddle/fluid/framework/new_executor/workqueue_test.cc +++ b/paddle/fluid/framework/new_executor/workqueue_test.cc @@ -43,3 +43,33 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { EXPECT_EQ(finished.load(), true); EXPECT_EQ(counter.load(), kLoopNum); } + +TEST(WorkQueue, TestMultiThreadedWorkQueue) { + VLOG(1) << "In Test"; + using paddle::framework::WorkQueue; + using paddle::framework::CreateMultiThreadedWorkQueue; + std::atomic finished{false}; + std::atomic counter{0}; + constexpr unsigned kExternalLoopNum = 100; + constexpr unsigned kLoopNum = 1000000; + // CreateSingleThreadedWorkQueue + std::unique_ptr work_queue = CreateMultiThreadedWorkQueue(10); + // NumThreads + EXPECT_EQ(work_queue->NumThreads(), 10u); + // AddTask + EXPECT_EQ(finished.load(), false); + EXPECT_EQ(counter.load(), 0u); + for (unsigned i = 0; i < kExternalLoopNum; ++i) { + work_queue->AddTask([&counter, &finished, kLoopNum]() { + for (unsigned i = 0; i < kLoopNum; ++i) { + ++counter; + } + finished = true; + }); + } + // WaitQueueEmpty + EXPECT_EQ(finished.load(), false); + work_queue->WaitQueueEmpty(); + EXPECT_EQ(finished.load(), true); + EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); +}