From a53460aaa1fc4f89ae71e0fe71bc7fa2e9a64b9d Mon Sep 17 00:00:00 2001 From: liutiexing <74819124+liutiexing@users.noreply.github.com> Date: Wed, 8 Sep 2021 10:08:27 +0800 Subject: [PATCH] Work queue group (#35470) * Split Tracker and WorkQueue * add WorkQueueGroup * add unittest * fix * update * update * fix compile --- .../framework/new_executor/CMakeLists.txt | 2 +- .../framework/new_executor/interpretercore.cc | 4 +- .../new_executor/nonblocking_threadpool.h | 102 ++++------- .../fluid/framework/new_executor/run_queue.h | 6 +- .../fluid/framework/new_executor/workqueue.cc | 165 ++++++++++++++---- .../fluid/framework/new_executor/workqueue.h | 50 +++++- .../framework/new_executor/workqueue_test.cc | 53 +++++- .../framework/new_executor/workqueue_utils.h | 63 +++++++ 8 files changed, 335 insertions(+), 110 deletions(-) create mode 100644 paddle/fluid/framework/new_executor/workqueue_utils.h diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index 23f0e02bb3e..73e4e7ad80c 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -1,4 +1,4 @@ -cc_library(workqueue SRCS workqueue.cc) +cc_library(workqueue SRCS workqueue.cc DEPS enforce) cc_library(interpretercore SRCS interpretercore.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index cadc3d3e5ef..63d36ddab3d 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -157,7 +157,9 @@ InterpreterCore::InterpreterCore(const platform::Place& place, garbages_.reset(new GarbageQueue()); max_memory_size_ = static_cast(GetEagerDeletionThreshold()); cur_memory_size_ = 0; - gc_queue_ = CreateSingleThreadedWorkQueue(); + WorkQueueOptions options; + options.num_threads = 1; + gc_queue_ = CreateSingleThreadedWorkQueue(options); feed_names_ = feed_names; diff --git a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h index 2ea1c49f98f..56edcecd17f 100644 --- a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h +++ b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h @@ -19,45 +19,46 @@ namespace paddle { namespace framework { -class CounterTracker { +class TaskTracker { public: - explicit CounterTracker(std::atomic* counter, EventCount* ec) - : counter_(counter), ec_(ec) { - counter_->fetch_add(1, std::memory_order_relaxed); - } + TaskTracker() : wait_empty_cv_(1) {} - ~CounterTracker() { - if (counter_ != nullptr) { - if (1 == counter_->fetch_sub(1, std::memory_order_relaxed)) { - ec_->Notify(true); - } - } - } + TaskTracker(const TaskTracker&) = delete; - CounterTracker(CounterTracker&& other) - : counter_(other.counter_), ec_(other.ec_) { - other.counter_ = nullptr; - other.ec_ = nullptr; - } + TaskTracker& operator=(const TaskTracker&) = delete; - CounterTracker& operator=(CounterTracker&& other) { - counter_ = other.counter_; - ec_ = other.ec_; - other.counter_ = nullptr; - other.ec_ = nullptr; - return *this; - } + ~TaskTracker() = default; - CounterTracker(const CounterTracker& other) - : counter_(other.counter_), ec_(other.ec_) { - counter_->fetch_add(1, std::memory_order_relaxed); + void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); } + + void SubCounter() { + if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) { + wait_empty_cv_.Notify(true); + } } - CounterTracker& operator=(const CounterTracker&) = delete; + // only one user can wait at any time + void WaitTaskNumToZero() { + bool waiting = false; + if (!wait_empty_.compare_exchange_strong(waiting, true, + std::memory_order_seq_cst, + std::memory_order_relaxed)) { + abort(); + } + EventCount::Waiter* w = wait_empty_cv_.GetWaiter(0); + wait_empty_cv_.Prewait(); + if (num_tasks_.load(std::memory_order_relaxed) == 0) { + wait_empty_cv_.CancelWait(); + } else { + wait_empty_cv_.CommitWait(w); + } + wait_empty_.store(false); + } private: - std::atomic* counter_{nullptr}; - EventCount* ec_{nullptr}; + std::atomic num_tasks_{0}; + EventCount wait_empty_cv_; + std::atomic wait_empty_{false}; }; template @@ -66,9 +67,6 @@ class ThreadPoolTempl { typedef typename Environment::Task Task; typedef RunQueue Queue; - explicit ThreadPoolTempl(int num_threads, Environment env = Environment()) - : ThreadPoolTempl(num_threads, true, env) {} - ThreadPoolTempl(int num_threads, bool allow_spinning, Environment env = Environment()) : env_(env), @@ -80,10 +78,7 @@ class ThreadPoolTempl { spinning_(0), done_(false), cancelled_(false), - ec_(num_threads_), - wait_empty_(false), - wait_empty_ec_(1), - num_tasks_(0) { + ec_(num_threads_) { // Calculate coprimes of all numbers [1, num_threads]. // Coprimes are used for random walks over all threads in Steal // and NonEmptyQueueIndex. Iteration is based on the fact that if we take @@ -146,15 +141,13 @@ class ThreadPoolTempl { } void AddTaskWithHint(std::function fn, int start, int limit) { - Task t = env_.CreateTask([ - task = std::move(fn), raii = CounterTracker(&num_tasks_, &wait_empty_ec_) - ]() mutable { task(); }); + Task t = env_.CreateTask(std::move(fn)); PerThread* pt = GetPerThread(); if (pt->pool == this) { // Worker thread of this pool, push onto the thread's queue. Queue& q = thread_data_[pt->thread_id].queue; t = q.PushFront(std::move(t)); - } else if (wait_empty_.load() == false) { + } else { // A free-standing thread (or worker of another pool), push onto a random // queue. assert(start < limit); @@ -179,29 +172,6 @@ class ThreadPoolTempl { } } - void WaitQueueEmpty() { - bool waiting = wait_empty_.load(); - assert(waiting == false); - if (waiting || - !wait_empty_.compare_exchange_strong(waiting, true, - std::memory_order_acquire)) { - abort(); - } - EventCount::Waiter* w = wait_empty_ec_.GetWaiter(0); - wait_empty_ec_.Prewait(); - if (num_tasks_.load() == 0) { - wait_empty_ec_.CancelWait(); - } else { - wait_empty_ec_.CommitWait(w); - } - waiting = true; - if (!waiting || - !wait_empty_.compare_exchange_strong(waiting, false, - std::memory_order_acquire)) { - abort(); - } - } - void Cancel() { cancelled_ = true; done_ = true; @@ -300,10 +270,6 @@ class ThreadPoolTempl { std::atomic cancelled_; EventCount ec_; - std::atomic wait_empty_; - EventCount wait_empty_ec_; - std::atomic num_tasks_; - // Main worker thread loop. void WorkerLoop(int thread_id) { PerThread* pt = GetPerThread(); diff --git a/paddle/fluid/framework/new_executor/run_queue.h b/paddle/fluid/framework/new_executor/run_queue.h index ead37fa4695..707aadd3158 100644 --- a/paddle/fluid/framework/new_executor/run_queue.h +++ b/paddle/fluid/framework/new_executor/run_queue.h @@ -194,7 +194,7 @@ class RunQueue { private: static const unsigned kMask = kSize - 1; static const unsigned kMask2 = (kSize << 1) - 1; - struct Elem { + struct alignas(64) Elem { std::atomic state; Work w; }; @@ -212,8 +212,8 @@ class RunQueue { // position, these conditions would be indistinguishable); (2) obtain // consistent snapshot of front_/back_ for Size operation using the // modification counters. - std::atomic front_; - std::atomic back_; + alignas(64) std::atomic front_; + alignas(64) std::atomic back_; Elem array_[kSize]; // SizeOrNotEmpty returns current queue size; if NeedSizeEstimate is false, diff --git a/paddle/fluid/framework/new_executor/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue.cc index 3fcc2fa1014..184d9d69984 100644 --- a/paddle/fluid/framework/new_executor/workqueue.cc +++ b/paddle/fluid/framework/new_executor/workqueue.cc @@ -6,63 +6,168 @@ #include "paddle/fluid/framework/new_executor/workqueue.h" #include "paddle/fluid/framework/new_executor/nonblocking_threadpool.h" +#include "paddle/fluid/framework/new_executor/workqueue_utils.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace framework { +namespace { -class SingleThreadedWorkQueue : public WorkQueue { +class WorkQueueImpl : public WorkQueue { public: - SingleThreadedWorkQueue() : queue_(1) {} - - SingleThreadedWorkQueue(const SingleThreadedWorkQueue&) = delete; - - SingleThreadedWorkQueue& operator=(const SingleThreadedWorkQueue&) = delete; + explicit WorkQueueImpl(const WorkQueueOptions& options) + : WorkQueue(options), queue_(nullptr), tracker_(nullptr) { + if (options_.track_task) { + tracker_ = new TaskTracker; + } + queue_ = new NonblockingThreadPool(options_.num_threads, + options_.allow_spinning); + } - virtual ~SingleThreadedWorkQueue() = default; + virtual ~WorkQueueImpl() { + delete tracker_; + delete queue_; + } void AddTask(std::function fn) override { - queue_.AddTask(std::move(fn)); + if (tracker_ != nullptr) { + fn = [ + task = std::move(fn), raii = CounterGuard(tracker_) + ]() mutable { + task(); + }; + } + queue_->AddTask(std::move(fn)); } - void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); } + void WaitQueueEmpty() override { + if (tracker_ == nullptr) { + PADDLE_THROW( + platform::errors::Unavailable("set WorkQueueOptions.track_task = " + "true before call this interface.")); + } + tracker_->WaitTaskNumToZero(); + } - size_t NumThreads() override { return queue_.NumThreads(); } + size_t NumThreads() const override { return queue_->NumThreads(); } private: - NonblockingThreadPool queue_; + NonblockingThreadPool* queue_; + TaskTracker* tracker_; }; -std::unique_ptr CreateSingleThreadedWorkQueue() { - std::unique_ptr ptr(new SingleThreadedWorkQueue); - return std::move(ptr); +class WorkQueueGroupImpl : public WorkQueueGroup { + public: + explicit WorkQueueGroupImpl( + const std::vector& queue_options); + + ~WorkQueueGroupImpl(); + + void AddTask(size_t queue_idx, std::function fn) override; + + void WaitQueueGroupEmpty() override; + + size_t QueueNumThreads(size_t queue_idx) const override; + + size_t QueueGroupNumThreads() const override; + + private: + std::vector queues_; + NonblockingThreadPool* queues_storage_; + TaskTracker* tracker_; +}; + +WorkQueueGroupImpl::WorkQueueGroupImpl( + const std::vector& queues_options) + : WorkQueueGroup(queues_options), + queues_storage_(nullptr), + tracker_(nullptr) { + size_t num_queues = queues_options_.size(); + queues_.resize(num_queues); + void* buffer = malloc(sizeof(NonblockingThreadPool) * num_queues); + queues_storage_ = reinterpret_cast(buffer); + for (size_t idx = 0; idx < num_queues; ++idx) { + const auto& options = queues_options_[idx]; + if (options.track_task && tracker_ == nullptr) { + tracker_ = new TaskTracker; + } + queues_[idx] = new (&queues_storage_[idx]) + NonblockingThreadPool(options.num_threads, options.allow_spinning); + } } -class MultiThreadedWorkQueue : public WorkQueue { - public: - explicit MultiThreadedWorkQueue(int num_threads) : queue_(num_threads) { - assert(num_threads > 1); +WorkQueueGroupImpl::~WorkQueueGroupImpl() { + for (auto queue : queues_) { + queue->~NonblockingThreadPool(); } + delete tracker_; + free(queues_storage_); +} - MultiThreadedWorkQueue(const MultiThreadedWorkQueue&) = delete; +void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function fn) { + assert(queue_idx < queues_.size()); + if (queues_options_.at(queue_idx).track_task) { + fn = [ + task = std::move(fn), raii = CounterGuard(tracker_) + ]() mutable { + task(); + }; + } + queues_[queue_idx]->AddTask(std::move(fn)); +} - MultiThreadedWorkQueue& operator=(const MultiThreadedWorkQueue&) = delete; +void WorkQueueGroupImpl::WaitQueueGroupEmpty() { + if (nullptr == tracker_) { + PADDLE_THROW(platform::errors::Unavailable( + "set WorkQueueOptions.track_task = true for at least one of queues " + "before call this interface.")); + } + tracker_->WaitTaskNumToZero(); +} - virtual ~MultiThreadedWorkQueue() = default; +size_t WorkQueueGroupImpl::QueueNumThreads(size_t queue_idx) const { + assert(queue_idx < queues_.size()); + return queues_.at(queue_idx)->NumThreads(); +} - void AddTask(std::function fn) override { - queue_.AddTask(std::move(fn)); +size_t WorkQueueGroupImpl::QueueGroupNumThreads() const { + size_t total_num = 0; + for (auto queue : queues_) { + total_num += queue->NumThreads(); } + return total_num; +} - void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); } +} // namespace - size_t NumThreads() override { return queue_.NumThreads(); } +std::unique_ptr CreateSingleThreadedWorkQueue( + const WorkQueueOptions& options) { + PADDLE_ENFORCE_EQ(options.num_threads, 1u, + platform::errors::InvalidArgument( + "For a SingleThreadedWorkQueue, " + "WorkQueueOptions.num_threads must equals to 1.")); + std::unique_ptr ptr(new WorkQueueImpl(options)); + return std::move(ptr); +} - private: - NonblockingThreadPool queue_; -}; +std::unique_ptr CreateMultiThreadedWorkQueue( + const WorkQueueOptions& options) { + PADDLE_ENFORCE_GT( + options.num_threads, 1u, + platform::errors::InvalidArgument("For a MultiThreadedWorkQueue, " + "WorkQueueOptions.num_threads must be " + "greater than 1.")); + std::unique_ptr ptr(new WorkQueueImpl(options)); + return std::move(ptr); +} -std::unique_ptr CreateMultiThreadedWorkQueue(int num_threads) { - std::unique_ptr ptr(new MultiThreadedWorkQueue(num_threads)); +std::unique_ptr CreateWorkQueueGroup( + const std::vector& queues_options) { + PADDLE_ENFORCE_GT(queues_options.size(), 1u, + platform::errors::InvalidArgument( + "For a WorkQueueGroup, the number of WorkQueueOptions " + "must be greater than 1.")); + std::unique_ptr ptr(new WorkQueueGroupImpl(queues_options)); return std::move(ptr); } diff --git a/paddle/fluid/framework/new_executor/workqueue.h b/paddle/fluid/framework/new_executor/workqueue.h index 402bd215d63..32e90641bbc 100644 --- a/paddle/fluid/framework/new_executor/workqueue.h +++ b/paddle/fluid/framework/new_executor/workqueue.h @@ -16,13 +16,20 @@ #include #include +#include namespace paddle { namespace framework { +struct WorkQueueOptions { + size_t num_threads{0}; + bool allow_spinning{true}; + bool track_task{false}; +}; + class WorkQueue { public: - WorkQueue() = default; + explicit WorkQueue(const WorkQueueOptions& options) : options_(options) {} WorkQueue(const WorkQueue&) = delete; @@ -32,14 +39,49 @@ class WorkQueue { virtual void AddTask(std::function fn) = 0; + // set WorkQueueOptions.track_task = true before call this + // interface, otherwise will abort() virtual void WaitQueueEmpty() = 0; - virtual size_t NumThreads() = 0; + virtual size_t NumThreads() const = 0; + + protected: + WorkQueueOptions options_; }; -std::unique_ptr CreateSingleThreadedWorkQueue(); +class WorkQueueGroup { + public: + explicit WorkQueueGroup(const std::vector& queues_options) + : queues_options_(queues_options) {} + + WorkQueueGroup(const WorkQueueGroup&) = delete; + + WorkQueueGroup& operator=(const WorkQueueGroup&) = delete; + + virtual ~WorkQueueGroup() = default; + + virtual void AddTask(size_t queue_idx, std::function fn) = 0; + + // set WorkQueueOptions.track_task = true for at least one of queues + // before call this interface, otherwise will abort() + virtual void WaitQueueGroupEmpty() = 0; + + virtual size_t QueueNumThreads(size_t queue_idx) const = 0; + + virtual size_t QueueGroupNumThreads() const = 0; + + protected: + std::vector queues_options_; +}; + +std::unique_ptr CreateSingleThreadedWorkQueue( + const WorkQueueOptions& options); + +std::unique_ptr CreateMultiThreadedWorkQueue( + const WorkQueueOptions& options); -std::unique_ptr CreateMultiThreadedWorkQueue(int num_threads); +std::unique_ptr CreateWorkQueueGroup( + const std::vector& queues_options); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue_test.cc index 691c78f3df2..cec1274259e 100644 --- a/paddle/fluid/framework/new_executor/workqueue_test.cc +++ b/paddle/fluid/framework/new_executor/workqueue_test.cc @@ -19,13 +19,17 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { VLOG(1) << "In Test"; + using paddle::framework::WorkQueueOptions; using paddle::framework::WorkQueue; using paddle::framework::CreateSingleThreadedWorkQueue; std::atomic finished{false}; std::atomic counter{0}; constexpr unsigned kLoopNum = 1000000; // CreateSingleThreadedWorkQueue - std::unique_ptr work_queue = CreateSingleThreadedWorkQueue(); + WorkQueueOptions options; + options.num_threads = 1; + options.track_task = true; + auto work_queue = CreateSingleThreadedWorkQueue(options); // NumThreads EXPECT_EQ(work_queue->NumThreads(), 1u); // AddTask @@ -46,14 +50,18 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { TEST(WorkQueue, TestMultiThreadedWorkQueue) { VLOG(1) << "In Test"; + using paddle::framework::WorkQueueOptions; using paddle::framework::WorkQueue; using paddle::framework::CreateMultiThreadedWorkQueue; std::atomic finished{false}; std::atomic counter{0}; constexpr unsigned kExternalLoopNum = 100; constexpr unsigned kLoopNum = 1000000; - // CreateSingleThreadedWorkQueue - std::unique_ptr work_queue = CreateMultiThreadedWorkQueue(10); + // CreateMultiThreadedWorkQueue + WorkQueueOptions options; + options.num_threads = 10; + options.track_task = true; + auto work_queue = CreateMultiThreadedWorkQueue(options); // NumThreads EXPECT_EQ(work_queue->NumThreads(), 10u); // AddTask @@ -73,3 +81,42 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { EXPECT_EQ(finished.load(), true); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); } + +TEST(WorkQueue, TestWorkQueueGroup) { + using paddle::framework::WorkQueueOptions; + using paddle::framework::WorkQueueGroup; + using paddle::framework::CreateWorkQueueGroup; + std::atomic finished{false}; + std::atomic counter{0}; + constexpr unsigned kExternalLoopNum = 100; + constexpr unsigned kLoopNum = 1000000; + // CreateMultiThreadedWorkQueue + WorkQueueOptions sq_options; + sq_options.num_threads = 1; + sq_options.track_task = true; + WorkQueueOptions mq_options; + mq_options.num_threads = 10; + mq_options.track_task = true; + auto queue_group = CreateWorkQueueGroup({sq_options, mq_options}); + // NumThreads + EXPECT_EQ(queue_group->QueueNumThreads(0), 1u); + EXPECT_EQ(queue_group->QueueNumThreads(1), 10u); + EXPECT_EQ(queue_group->QueueGroupNumThreads(), 11u); + // AddTask + EXPECT_EQ(counter.load(), 0u); + for (unsigned i = 0; i < kExternalLoopNum; ++i) { + queue_group->AddTask(1, [&counter, &finished, kLoopNum]() { + for (unsigned i = 0; i < kLoopNum; ++i) { + ++counter; + } + }); + } + queue_group->AddTask(0, [&counter, &finished, kLoopNum]() { + for (unsigned i = 0; i < kLoopNum; ++i) { + ++counter; + } + }); + // WaitQueueGroupEmpty() + queue_group->WaitQueueGroupEmpty(); + EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); +} diff --git a/paddle/fluid/framework/new_executor/workqueue_utils.h b/paddle/fluid/framework/new_executor/workqueue_utils.h new file mode 100644 index 00000000000..00183eadcbb --- /dev/null +++ b/paddle/fluid/framework/new_executor/workqueue_utils.h @@ -0,0 +1,63 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { + +template +class CounterGuard { + public: + explicit CounterGuard(Holder* holder) : counter_holder_(holder) { + assert(holder != nullptr); + counter_holder_->AddCounter(); + } + + ~CounterGuard() { + if (counter_holder_ != nullptr) { + counter_holder_->SubCounter(); + } + } + + CounterGuard(CounterGuard&& other) : counter_holder_(other.counter_holder_) { + other.counter_holder_ = nullptr; + } + + CounterGuard& operator=(CounterGuard&& other) { + counter_holder_ = other.counter_holder_; + other.counter_holder_ = nullptr; + return *this; + } + + // copy constructor deleted, we define this for std::function + // never use it directly + CounterGuard(const CounterGuard& other) { + PADDLE_THROW(platform::errors::Unavailable( + "Never use the copy constructor of CounterGuard.")); + } + + CounterGuard& operator=(const CounterGuard&) = delete; + + private: + Holder* counter_holder_{nullptr}; +}; + +} // namespace framework +} // namespace paddle -- GitLab