diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index 80f9d343de0556daad86a0f04ecf0a827cb84083..d7204b816238e91cd60a7fe8f57e3d63a5e1debe 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -1,4 +1,6 @@ cc_library(interpretercore SRCS interpretercore.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ${PYBIND_DEPS} profiler) cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ${PYBIND_DEPS} profiler) +cc_library(workqueue SRCS workqueue.cc) +cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue) # cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) diff --git a/paddle/fluid/framework/new_executor/event_count.h b/paddle/fluid/framework/new_executor/event_count.h new file mode 100644 index 0000000000000000000000000000000000000000..f374456ca38141859d42d451a291b82dfbfa85ca --- /dev/null +++ b/paddle/fluid/framework/new_executor/event_count.h @@ -0,0 +1,272 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2016 Dmitry Vyukov +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// EventCount allows to wait for arbitrary predicates in non-blocking +// algorithms. Think of condition variable, but wait predicate does not need to +// be protected by a mutex. Usage: +// Waiting thread does: +// +// if (predicate) +// return act(); +// EventCount::Waiter& w = waiters[my_index]; +// ec.Prewait(&w); +// if (predicate) { +// ec.CancelWait(&w); +// return act(); +// } +// ec.CommitWait(&w); +// +// Notifying thread does: +// +// predicate = true; +// ec.Notify(true); +// +// Notify is cheap if there are no waiting threads. Prewait/CommitWait are not +// cheap, but they are executed only if the preceding predicate check has +// failed. +// +// Algorithm outline: +// There are two main variables: predicate (managed by user) and state_. +// Operation closely resembles Dekker mutual algorithm: +// https://en.wikipedia.org/wiki/Dekker%27s_algorithm +// Waiting thread sets state_ then checks predicate, Notifying thread sets +// predicate then checks state_. Due to seq_cst fences in between these +// operations it is guaranteed than either waiter will see predicate change +// and won't block, or notifying thread will see state_ change and will unblock +// the waiter, or both. But it can't happen that both threads don't see each +// other changes, which would lead to deadlock. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace framework { + +class EventCount { + public: + class Waiter; + + explicit EventCount(size_t waiter_num) : state_(kStackMask) { + assert(waiter_num < (1 << kWaiterBits) - 1); + void* buffer = malloc(sizeof(Waiter) * waiter_num); + if (buffer == nullptr) { + return; + } + waiters_ = reinterpret_cast(buffer); + waiter_num_ = waiter_num; + for (size_t i = 0; i < waiter_num_; ++i) { + new (&waiters_[i]) Waiter; + } + } + + EventCount(const EventCount&) = delete; + + void operator=(const EventCount&) = delete; + + ~EventCount() { + // Ensure there are no waiters. + assert(state_.load() == kStackMask); + free(waiters_); + } + + Waiter* GetWaiter(size_t waiter_index) { + assert(waiter_index < waiter_num_); + return &waiters_[waiter_index]; + } + + // Prewait prepares for waiting. + // After calling Prewait, the thread must re-check the wait predicate + // and then call either CancelWait or CommitWait. + void Prewait() { + uint64_t state = state_.load(std::memory_order_relaxed); + for (;;) { + CheckState(state); + uint64_t newstate = state + kWaiterInc; + CheckState(newstate); + if (state_.compare_exchange_weak(state, newstate, + std::memory_order_seq_cst)) + return; + } + } + + // CommitWait commits waiting after Prewait. + void CommitWait(Waiter* w) { + assert((w->epoch & ~kEpochMask) == 0); + w->state = Waiter::kNotSignaled; + const uint64_t me = (w - &waiters_[0]) | w->epoch; + uint64_t state = state_.load(std::memory_order_seq_cst); + for (;;) { + CheckState(state, true); + uint64_t newstate; + if ((state & kSignalMask) != 0) { + // Consume the signal and return immidiately. + newstate = state - kWaiterInc - kSignalInc; + } else { + // Remove this thread from pre-wait counter and add to the waiter stack. + newstate = ((state & kWaiterMask) - kWaiterInc) | me; + w->next.store(state & (kStackMask | kEpochMask), + std::memory_order_relaxed); + } + CheckState(newstate); + if (state_.compare_exchange_weak(state, newstate, + std::memory_order_acq_rel)) { + if ((state & kSignalMask) == 0) { + w->epoch += kEpochInc; + Park(w); + } + return; + } + } + } + + // CancelWait cancels effects of the previous Prewait call. + void CancelWait() { + uint64_t state = state_.load(std::memory_order_relaxed); + for (;;) { + CheckState(state, true); + uint64_t newstate = state - kWaiterInc; + // We don't know if the thread was also notified or not, + // so we should not consume a signal unconditionaly. + // Only if number of waiters is equal to number of signals, + // we know that the thread was notified and we must take away the signal. + if (((state & kWaiterMask) >> kWaiterShift) == + ((state & kSignalMask) >> kSignalShift)) + newstate -= kSignalInc; + CheckState(newstate); + if (state_.compare_exchange_weak(state, newstate, + std::memory_order_acq_rel)) + return; + } + } + + // Notify wakes one or all waiting threads. + // Must be called after changing the associated wait predicate. + void Notify(bool notify_all) { + std::atomic_thread_fence(std::memory_order_seq_cst); + uint64_t state = state_.load(std::memory_order_acquire); + for (;;) { + CheckState(state); + const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift; + const uint64_t signals = (state & kSignalMask) >> kSignalShift; + // Easy case: no waiters. + if ((state & kStackMask) == kStackMask && waiters == signals) return; + uint64_t newstate; + if (notify_all) { + // Empty wait stack and set signal to number of pre-wait threads. + newstate = + (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask; + } else if (signals < waiters) { + // There is a thread in pre-wait state, unblock it. + newstate = state + kSignalInc; + } else { + // Pop a waiter from list and unpark it. + Waiter* w = &waiters_[state & kStackMask]; + uint64_t next = w->next.load(std::memory_order_relaxed); + newstate = (state & (kWaiterMask | kSignalMask)) | next; + } + CheckState(newstate); + if (state_.compare_exchange_weak(state, newstate, + std::memory_order_acq_rel)) { + if (!notify_all && (signals < waiters)) + return; // unblocked pre-wait thread + if ((state & kStackMask) == kStackMask) return; + Waiter* w = &waiters_[state & kStackMask]; + if (!notify_all) w->next.store(kStackMask, std::memory_order_relaxed); + Unpark(w); + return; + } + } + } + + class Waiter { + friend class EventCount; + // Align to 128 byte boundary to prevent false sharing with other Waiter + // objects in the same vector. + alignas(128) std::atomic next; + std::mutex mu; + std::condition_variable cv; + uint64_t epoch = 0; + unsigned state = kNotSignaled; + enum { + kNotSignaled, + kWaiting, + kSignaled, + }; + }; + + private: + // State_ layout: + // - low kWaiterBits is a stack of waiters committed wait + // (indexes in waiters_ array are used as stack elements, + // kStackMask means empty stack). + // - next kWaiterBits is count of waiters in prewait state. + // - next kWaiterBits is count of pending signals. + // - remaining bits are ABA counter for the stack. + // (stored in Waiter node and incremented on push). + static const uint64_t kWaiterBits = 14; + static const uint64_t kStackMask = (1ull << kWaiterBits) - 1; + static const uint64_t kWaiterShift = kWaiterBits; + static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1) + << kWaiterShift; + static const uint64_t kWaiterInc = 1ull << kWaiterShift; + static const uint64_t kSignalShift = 2 * kWaiterBits; + static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1) + << kSignalShift; + static const uint64_t kSignalInc = 1ull << kSignalShift; + static const uint64_t kEpochShift = 3 * kWaiterBits; + static const uint64_t kEpochBits = 64 - kEpochShift; + static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift; + static const uint64_t kEpochInc = 1ull << kEpochShift; + std::atomic state_; + Waiter* waiters_{nullptr}; + size_t waiter_num_{0}; + + static void CheckState(uint64_t state, bool waiter = false) { + static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem"); + const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift; + const uint64_t signals = (state & kSignalMask) >> kSignalShift; + assert(waiters >= signals); + assert(waiters < (1 << kWaiterBits) - 1); + assert(!waiter || waiters > 0); + (void)waiters; + (void)signals; + } + + void Park(Waiter* w) { + std::unique_lock lock(w->mu); + while (w->state != Waiter::kSignaled) { + w->state = Waiter::kWaiting; + w->cv.wait(lock); + } + } + + void Unpark(Waiter* w) { + for (Waiter* next; w; w = next) { + uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask; + next = wnext == kStackMask ? nullptr : &waiters_[wnext]; + unsigned state; + { + std::unique_lock lock(w->mu); + state = w->state; + w->state = Waiter::kSignaled; + } + // Avoid notifying if it wasn't waiting. + if (state == Waiter::kWaiting) w->cv.notify_one(); + } + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h new file mode 100644 index 0000000000000000000000000000000000000000..2ea1c49f98fa3126cab2fed97771416cb643e236 --- /dev/null +++ b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h @@ -0,0 +1,516 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2016 Dmitry Vyukov +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/new_executor/event_count.h" +#include "paddle/fluid/framework/new_executor/run_queue.h" +#include "paddle/fluid/framework/new_executor/thread_environment.h" + +namespace paddle { +namespace framework { + +class CounterTracker { + public: + explicit CounterTracker(std::atomic* counter, EventCount* ec) + : counter_(counter), ec_(ec) { + counter_->fetch_add(1, std::memory_order_relaxed); + } + + ~CounterTracker() { + if (counter_ != nullptr) { + if (1 == counter_->fetch_sub(1, std::memory_order_relaxed)) { + ec_->Notify(true); + } + } + } + + CounterTracker(CounterTracker&& other) + : counter_(other.counter_), ec_(other.ec_) { + other.counter_ = nullptr; + other.ec_ = nullptr; + } + + CounterTracker& operator=(CounterTracker&& other) { + counter_ = other.counter_; + ec_ = other.ec_; + other.counter_ = nullptr; + other.ec_ = nullptr; + return *this; + } + + CounterTracker(const CounterTracker& other) + : counter_(other.counter_), ec_(other.ec_) { + counter_->fetch_add(1, std::memory_order_relaxed); + } + + CounterTracker& operator=(const CounterTracker&) = delete; + + private: + std::atomic* counter_{nullptr}; + EventCount* ec_{nullptr}; +}; + +template +class ThreadPoolTempl { + public: + typedef typename Environment::Task Task; + typedef RunQueue Queue; + + explicit ThreadPoolTempl(int num_threads, Environment env = Environment()) + : ThreadPoolTempl(num_threads, true, env) {} + + ThreadPoolTempl(int num_threads, bool allow_spinning, + Environment env = Environment()) + : env_(env), + num_threads_(num_threads), + allow_spinning_(allow_spinning), + thread_data_(num_threads), + global_steal_partition_(EncodePartition(0, num_threads_)), + blocked_(0), + spinning_(0), + done_(false), + cancelled_(false), + ec_(num_threads_), + wait_empty_(false), + wait_empty_ec_(1), + num_tasks_(0) { + // Calculate coprimes of all numbers [1, num_threads]. + // Coprimes are used for random walks over all threads in Steal + // and NonEmptyQueueIndex. Iteration is based on the fact that if we take + // a random starting thread index t and calculate num_threads - 1 subsequent + // indices as (t + coprime) % num_threads, we will cover all threads without + // repetitions (effectively getting a presudo-random permutation of thread + // indices). + assert(num_threads_ >= 1 && num_threads_ < kMaxThreads); + all_coprimes_.reserve(num_threads_); + for (int i = 1; i <= num_threads_; ++i) { + all_coprimes_.emplace_back(); + all_coprimes_.back().push_back(i); + ComputeCoprimes(i, &(all_coprimes_.back())); + } + for (int i = 0; i < num_threads_; i++) { + SetStealPartition(i, EncodePartition(0, num_threads_)); + thread_data_[i].thread.reset( + env_.CreateThread([this, i]() { WorkerLoop(i); })); + } + } + + ~ThreadPoolTempl() { + done_ = true; + + // Now if all threads block without work, they will start exiting. + // But note that threads can continue to work arbitrary long, + // block, submit new work, unblock and otherwise live full life. + if (!cancelled_) { + ec_.Notify(true); + } else { + // Since we were cancelled, there might be entries in the queues. + // Empty them to prevent their destructor from asserting. + for (size_t i = 0; i < thread_data_.size(); i++) { + thread_data_[i].queue.Flush(); + } + } + // Join threads explicitly (by destroying) to avoid destruction order within + // this class. + for (size_t i = 0; i < thread_data_.size(); ++i) { + thread_data_[i].thread.reset(); + } + } + + void SetStealPartitions( + const std::vector>& partitions) { + assert(partitions.size() == static_cast(num_threads_)); + + // Pass this information to each thread queue. + for (int i = 0; i < num_threads_; i++) { + const auto& pair = partitions[i]; + unsigned start = pair.first, end = pair.second; + AssertBounds(start, end); + unsigned val = EncodePartition(start, end); + SetStealPartition(i, val); + } + } + + void AddTask(std::function fn) { + AddTaskWithHint(std::move(fn), 0, num_threads_); + } + + void AddTaskWithHint(std::function fn, int start, int limit) { + Task t = env_.CreateTask([ + task = std::move(fn), raii = CounterTracker(&num_tasks_, &wait_empty_ec_) + ]() mutable { task(); }); + PerThread* pt = GetPerThread(); + if (pt->pool == this) { + // Worker thread of this pool, push onto the thread's queue. + Queue& q = thread_data_[pt->thread_id].queue; + t = q.PushFront(std::move(t)); + } else if (wait_empty_.load() == false) { + // A free-standing thread (or worker of another pool), push onto a random + // queue. + assert(start < limit); + assert(limit <= num_threads_); + int num_queues = limit - start; + int rnd = Rand(&pt->rand) % num_queues; + assert(start + rnd < limit); + Queue& q = thread_data_[start + rnd].queue; + t = q.PushBack(std::move(t)); + } + // Note: below we touch this after making w available to worker threads. + // Strictly speaking, this can lead to a racy-use-after-free. Consider that + // Schedule is called from a thread that is neither main thread nor a worker + // thread of this pool. Then, execution of w directly or indirectly + // completes overall computations, which in turn leads to destruction of + // this. We expect that such scenario is prevented by program, that is, + // this is kept alive while any threads can potentially be in Schedule. + if (!t.f) { + ec_.Notify(false); + } else { + env_.ExecuteTask(t); // Push failed, execute directly. + } + } + + void WaitQueueEmpty() { + bool waiting = wait_empty_.load(); + assert(waiting == false); + if (waiting || + !wait_empty_.compare_exchange_strong(waiting, true, + std::memory_order_acquire)) { + abort(); + } + EventCount::Waiter* w = wait_empty_ec_.GetWaiter(0); + wait_empty_ec_.Prewait(); + if (num_tasks_.load() == 0) { + wait_empty_ec_.CancelWait(); + } else { + wait_empty_ec_.CommitWait(w); + } + waiting = true; + if (!waiting || + !wait_empty_.compare_exchange_strong(waiting, false, + std::memory_order_acquire)) { + abort(); + } + } + + void Cancel() { + cancelled_ = true; + done_ = true; + + // Wake up the threads without work to let them exit on their own. + ec_.Notify(true); + } + + size_t NumThreads() const { return num_threads_; } + + int CurrentThreadId() const { + const PerThread* pt = const_cast(this)->GetPerThread(); + if (pt->pool == this) { + return pt->thread_id; + } else { + return -1; + } + } + + private: + // Create a single atomic that encodes start and limit information for + // each thread. + // We expect num_threads_ < 65536, so we can store them in a single + // std::atomic. + // Exposed publicly as static functions so that external callers can reuse + // this encode/decode logic for maintaining their own thread-safe copies of + // scheduling and steal domain(s). + static const int kMaxPartitionBits = 16; + static const int kMaxThreads = 1 << kMaxPartitionBits; + + inline unsigned EncodePartition(unsigned start, unsigned limit) { + return (start << kMaxPartitionBits) | limit; + } + + inline void DecodePartition(unsigned val, unsigned* start, unsigned* limit) { + *limit = val & (kMaxThreads - 1); + val >>= kMaxPartitionBits; + *start = val; + } + + void AssertBounds(int start, int end) { + assert(start >= 0); + assert(start < end); // non-zero sized partition + assert(end <= num_threads_); + } + + inline void SetStealPartition(size_t i, unsigned val) { + thread_data_[i].steal_partition.store(val, std::memory_order_relaxed); + } + + inline unsigned GetStealPartition(int i) { + return thread_data_[i].steal_partition.load(std::memory_order_relaxed); + } + + inline void ComputeCoprimes(int n, std::vector* coprimes) { + for (int i = 1; i <= n; i++) { + unsigned a = i; + unsigned b = n; + // If GCD(a, b) == 1, then a and b are coprimes. + while (b != 0) { + unsigned tmp = a; + a = b; + b = tmp % b; + } + if (a == 1) { + coprimes->push_back(i); + } + } + } + + typedef typename Environment::EnvThread Thread; + + struct PerThread { + constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) {} + ThreadPoolTempl* pool; // Parent pool, or null for normal threads. + uint64_t rand; // Random generator state. + int thread_id; // Worker thread index in pool. + }; + + struct ThreadData { + constexpr ThreadData() : thread(), steal_partition(0), queue() {} + std::unique_ptr thread; + std::atomic steal_partition; + Queue queue; + }; + + Environment env_; + const int num_threads_; + const bool allow_spinning_; + std::vector thread_data_; + std::vector> all_coprimes_; + unsigned global_steal_partition_; + std::atomic blocked_; + std::atomic spinning_; + std::atomic done_; + std::atomic cancelled_; + EventCount ec_; + + std::atomic wait_empty_; + EventCount wait_empty_ec_; + std::atomic num_tasks_; + + // Main worker thread loop. + void WorkerLoop(int thread_id) { + PerThread* pt = GetPerThread(); + pt->pool = this; + pt->rand = GlobalThreadIdHash(); + pt->thread_id = thread_id; + Queue& q = thread_data_[thread_id].queue; + EventCount::Waiter* waiter = ec_.GetWaiter(thread_id); + // TODO(dvyukov,rmlarsen): The time spent in NonEmptyQueueIndex() is + // proportional to num_threads_ and we assume that new work is scheduled at + // a constant rate, so we set spin_count to 5000 / num_threads_. The + // constant was picked based on a fair dice roll, tune it. + const int spin_count = + allow_spinning_ && num_threads_ > 0 ? 5000 / num_threads_ : 0; + if (num_threads_ == 1) { + // For num_threads_ == 1 there is no point in going through the expensive + // steal loop. Moreover, since NonEmptyQueueIndex() calls PopBack() on the + // victim queues it might reverse the order in which ops are executed + // compared to the order in which they are added, which tends to be + // counter-productive for the types of I/O workloads the single thread + // pools tend to be used for. + while (!cancelled_) { + Task t = q.PopFront(); + for (int i = 0; i < spin_count && !t.f; i++) { + if (!cancelled_.load(std::memory_order_relaxed)) { + t = q.PopFront(); + } + } + if (!t.f) { + if (!WaitForWork(waiter, &t)) { + return; + } + } + if (t.f) { + env_.ExecuteTask(t); + } + } + } else { + while (!cancelled_) { + Task t = q.PopFront(); + if (!t.f) { + t = LocalSteal(); + if (!t.f) { + t = GlobalSteal(); + if (!t.f) { + // Leave one thread spinning. This reduces latency. + if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) { + for (int i = 0; i < spin_count && !t.f; i++) { + if (!cancelled_.load(std::memory_order_relaxed)) { + t = GlobalSteal(); + } else { + return; + } + } + spinning_ = false; + } + if (!t.f) { + if (!WaitForWork(waiter, &t)) { + return; + } + } + } + } + } + if (t.f) { + env_.ExecuteTask(t); + } + } + } + } + + // Steal tries to steal work from other worker threads in the range [start, + // limit) in best-effort manner. + Task Steal(unsigned start, unsigned limit) { + PerThread* pt = GetPerThread(); + const size_t size = limit - start; + unsigned r = Rand(&pt->rand); + // Reduce r into [0, size) range, this utilizes trick from + // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ + assert(all_coprimes_[size - 1].size() < (1 << 30)); + unsigned victim = ((uint64_t)r * (uint64_t)size) >> 32; + unsigned index = + ((uint64_t)all_coprimes_[size - 1].size() * (uint64_t)r) >> 32; + unsigned inc = all_coprimes_[size - 1][index]; + + for (unsigned i = 0; i < size; i++) { + assert(start + victim < limit); + Task t = thread_data_[start + victim].queue.PopBack(); + if (t.f) { + return t; + } + victim += inc; + if (victim >= size) { + victim -= size; + } + } + return Task(); + } + + // Steals work within threads belonging to the partition. + Task LocalSteal() { + PerThread* pt = GetPerThread(); + unsigned partition = GetStealPartition(pt->thread_id); + // If thread steal partition is the same as global partition, there is no + // need to go through the steal loop twice. + if (global_steal_partition_ == partition) return Task(); + unsigned start, limit; + DecodePartition(partition, &start, &limit); + AssertBounds(start, limit); + + return Steal(start, limit); + } + + // Steals work from any other thread in the pool. + Task GlobalSteal() { return Steal(0, num_threads_); } + + // WaitForWork blocks until new work is available (returns true), or if it is + // time to exit (returns false). Can optionally return a task to execute in t + // (in such case t.f != nullptr on return). + bool WaitForWork(EventCount::Waiter* waiter, Task* t) { + assert(t != nullptr && !t->f); + // We already did best-effort emptiness check in Steal, so prepare for + // blocking. + ec_.Prewait(); + // Now do a reliable emptiness check. + int victim = NonEmptyQueueIndex(); + if (victim != -1) { + ec_.CancelWait(); + if (cancelled_) { + return false; + } else { + *t = thread_data_[victim].queue.PopBack(); + return true; + } + } + // Number of blocked threads is used as termination condition. + // If we are shutting down and all worker threads blocked without work, + // that's we are done. + blocked_++; + if (done_ && blocked_ == static_cast(num_threads_)) { + ec_.CancelWait(); + // Almost done, but need to re-check queues. + // Consider that all queues are empty and all worker threads are preempted + // right after incrementing blocked_ above. Now a free-standing thread + // submits work and calls destructor (which sets done_). If we don't + // re-check queues, we will exit leaving the work unexecuted. + if (NonEmptyQueueIndex() != -1) { + // Note: we must not pop from queues before we decrement blocked_, + // otherwise the following scenario is possible. Consider that instead + // of checking for emptiness we popped the only element from queues. + // Now other worker threads can start exiting, which is bad if the + // work item submits other work. So we just check emptiness here, + // which ensures that all worker threads exit at the same time. + blocked_--; + return true; + } + // Reached stable termination state. + ec_.Notify(true); + return false; + } + ec_.CommitWait(waiter); + blocked_--; + return true; + } + + int NonEmptyQueueIndex() { + PerThread* pt = GetPerThread(); + // We intentionally design NonEmptyQueueIndex to steal work from + // anywhere in the queue so threads don't block in WaitForWork() forever + // when all threads in their partition go to sleep. Steal is still local. + const size_t size = thread_data_.size(); + unsigned r = Rand(&pt->rand); + unsigned inc = all_coprimes_[size - 1][r % all_coprimes_[size - 1].size()]; + unsigned victim = r % size; + for (unsigned i = 0; i < size; i++) { + if (!thread_data_[victim].queue.Empty()) { + return victim; + } + victim += inc; + if (victim >= size) { + victim -= size; + } + } + return -1; + } + + static inline uint64_t GlobalThreadIdHash() { + return std::hash()(std::this_thread::get_id()); + } + + inline PerThread* GetPerThread() { + static thread_local PerThread per_thread_; + PerThread* pt = &per_thread_; + return pt; + } + + static inline unsigned Rand(uint64_t* state) { + uint64_t current = *state; + // Update the internal state + *state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL; + // Generate the random output (using the PCG-XSH-RS scheme) + return static_cast((current ^ (current >> 22)) >> + (22 + (current >> 61))); + } +}; + +using NonblockingThreadPool = ThreadPoolTempl; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/run_queue.h b/paddle/fluid/framework/new_executor/run_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..ead37fa4695f5af83b1733ab0155cb35a94c39e8 --- /dev/null +++ b/paddle/fluid/framework/new_executor/run_queue.h @@ -0,0 +1,267 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2016 Dmitry Vyukov +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// RunQueue is a fixed-size, partially non-blocking deque or Work items. +// Operations on front of the queue must be done by a single thread (owner), +// operations on back of the queue can be done by multiple threads concurrently. +// +// Algorithm outline: +// All remote threads operating on the queue back are serialized by a mutex. +// This ensures that at most two threads access state: owner and one remote +// thread (Size aside). The algorithm ensures that the occupied region of the +// underlying array is logically continuous (can wraparound, but no stray +// occupied elements). Owner operates on one end of this region, remote thread +// operates on the other end. Synchronization between these threads +// (potential consumption of the last element and take up of the last empty +// element) happens by means of state variable in each element. States are: +// empty, busy (in process of insertion of removal) and ready. Threads claim +// elements (empty->busy and ready->busy transitions) by means of a CAS +// operation. The finishing transition (busy->empty and busy->ready) are done +// with plain store as the element is exclusively owned by the current thread. +// +// Note: we could permit only pointers as elements, then we would not need +// separate state variable as null/non-null pointer value would serve as state, +// but that would require malloc/free per operation for large, complex values +// (and this is designed to store std::function<()>). + +#pragma once + +#include +#include +#include +#include +#include + +namespace paddle { +namespace framework { + +template +class RunQueue { + public: + RunQueue() : front_(0), back_(0) { + // require power-of-two for fast masking + static_assert((kSize & (kSize - 1)) == 0, + "need to be a power of two for fast masking"); + static_assert( + kSize > 2, + "need to be in [4, 65536] range to leave enough space for counter"); + static_assert( + kSize <= (64 << 10), + "need to be in [4, 65536] range to leave enough space for counter"); + for (unsigned i = 0; i < kSize; i++) + array_[i].state.store(kEmpty, std::memory_order_relaxed); + } + + RunQueue(const RunQueue&) = delete; + void operator=(const RunQueue&) = delete; + + ~RunQueue() { assert(Size() == 0); } + + // PushFront inserts w at the beginning of the queue. + // If queue is full returns w, otherwise returns default-constructed Work. + Work PushFront(Work w) { + unsigned front = front_.load(std::memory_order_relaxed); + Elem* e = &array_[front & kMask]; + uint8_t s = e->state.load(std::memory_order_relaxed); + if (s != kEmpty || + !e->state.compare_exchange_strong(s, kBusy, + std::memory_order_acquire)) { + return w; + } + front_.store(front + 1 + (kSize << 1), std::memory_order_relaxed); + e->w = std::move(w); + e->state.store(kReady, std::memory_order_release); + return Work(); + } + + // PopFront removes and returns the first element in the queue. + // If the queue was empty returns default-constructed Work. + Work PopFront() { + unsigned front = front_.load(std::memory_order_relaxed); + Elem* e = &array_[(front - 1) & kMask]; + uint8_t s = e->state.load(std::memory_order_relaxed); + if (s != kReady || + !e->state.compare_exchange_strong(s, kBusy, + std::memory_order_acquire)) { + return Work(); + } + Work w = std::move(e->w); + e->state.store(kEmpty, std::memory_order_release); + front = ((front - 1) & kMask2) | (front & ~kMask2); + front_.store(front, std::memory_order_relaxed); + return w; + } + + // PushBack adds w at the end of the queue. + // If queue is full returns w, otherwise returns default-constructed Work. + Work PushBack(Work w) { + std::unique_lock lock(mutex_); + unsigned back = back_.load(std::memory_order_relaxed); + Elem* e = &array_[(back - 1) & kMask]; + uint8_t s = e->state.load(std::memory_order_relaxed); + if (s != kEmpty || + !e->state.compare_exchange_strong(s, kBusy, + std::memory_order_acquire)) { + return w; + } + back = ((back - 1) & kMask2) | (back & ~kMask2); + back_.store(back, std::memory_order_relaxed); + e->w = std::move(w); + e->state.store(kReady, std::memory_order_release); + return Work(); + } + + // PopBack removes and returns the last elements in the queue. + Work PopBack() { + if (Empty()) { + return Work(); + } + + std::unique_lock lock(mutex_); + unsigned back = back_.load(std::memory_order_relaxed); + Elem* e = &array_[back & kMask]; + uint8_t s = e->state.load(std::memory_order_relaxed); + if (s != kReady || + !e->state.compare_exchange_strong(s, kBusy, + std::memory_order_acquire)) { + return Work(); + } + Work w = std::move(e->w); + e->state.store(kEmpty, std::memory_order_release); + back_.store(back + 1 + (kSize << 1), std::memory_order_relaxed); + return w; + } + + // PopBackHalf removes and returns half last elements in the queue. + // Returns number of elements removed. + unsigned PopBackHalf(std::vector* result) { + if (Empty()) { + return 0; + } + + std::unique_lock lock(mutex_); + unsigned back = back_.load(std::memory_order_relaxed); + unsigned size = Size(); + unsigned mid = back; + if (size > 1) mid = back + (size - 1) / 2; + unsigned n = 0; + unsigned start = 0; + for (; static_cast(mid - back) >= 0; mid--) { + Elem* e = &array_[mid & kMask]; + uint8_t s = e->state.load(std::memory_order_relaxed); + if (n == 0) { + if (s != kReady || + !e->state.compare_exchange_strong(s, kBusy, + std::memory_order_acquire)) + continue; + start = mid; + } else { + // Note: no need to store temporal kBusy, we exclusively own these + // elements. + assert(s == kReady); + } + result->push_back(std::move(e->w)); + e->state.store(kEmpty, std::memory_order_release); + n++; + } + if (n != 0) { + back_.store(start + 1 + (kSize << 1), std::memory_order_relaxed); + } + return n; + } + + // Size returns current queue size. + // Can be called by any thread at any time. + unsigned Size() const { return SizeOrNotEmpty(); } + + // Empty tests whether container is empty. + // Can be called by any thread at any time. + bool Empty() const { return SizeOrNotEmpty() == 0; } + + // Delete all the elements from the queue. + void Flush() { + while (!Empty()) { + PopFront(); + } + } + + private: + static const unsigned kMask = kSize - 1; + static const unsigned kMask2 = (kSize << 1) - 1; + struct Elem { + std::atomic state; + Work w; + }; + enum { + kEmpty, + kBusy, + kReady, + }; + + std::mutex mutex_; + // Low log(kSize) + 1 bits in front_ and back_ contain rolling index of + // front/back, respectively. The remaining bits contain modification counters + // that are incremented on Push operations. This allows us to (1) distinguish + // between empty and full conditions (if we would use log(kSize) bits for + // position, these conditions would be indistinguishable); (2) obtain + // consistent snapshot of front_/back_ for Size operation using the + // modification counters. + std::atomic front_; + std::atomic back_; + Elem array_[kSize]; + + // SizeOrNotEmpty returns current queue size; if NeedSizeEstimate is false, + // only whether the size is 0 is guaranteed to be correct. + // Can be called by any thread at any time. + template + unsigned SizeOrNotEmpty() const { + // Emptiness plays critical role in thread pool blocking. So we go to great + // effort to not produce false positives (claim non-empty queue as empty). + unsigned front = front_.load(std::memory_order_acquire); + for (;;) { + // Capture a consistent snapshot of front/tail. + unsigned back = back_.load(std::memory_order_acquire); + unsigned front1 = front_.load(std::memory_order_relaxed); + if (front != front1) { + front = front1; + std::atomic_thread_fence(std::memory_order_acquire); + continue; + } + if (NeedSizeEstimate) { + return CalculateSize(front, back); + } else { + // This value will be 0 if the queue is empty, and undefined otherwise. + unsigned maybe_zero = ((front ^ back) & kMask2); + // Queue size estimate must agree with maybe zero check on the queue + // empty/non-empty state. + assert((CalculateSize(front, back) == 0) == (maybe_zero == 0)); + return maybe_zero; + } + } + } + + inline unsigned CalculateSize(unsigned front, unsigned back) const { + int size = (front & kMask2) - (back & kMask2); + // Fix overflow. + if (size < 0) { + size += 2 * kSize; + } + // Order of modification in push/pop is crafted to make the queue look + // larger than it is during concurrent modifications. E.g. push can + // increment size before the corresponding pop has decremented it. + // So the computed size can be up to kSize + 1, fix it. + if (size > static_cast(kSize)) { + size = kSize; + } + return static_cast(size); + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/thread_environment.h b/paddle/fluid/framework/new_executor/thread_environment.h new file mode 100644 index 0000000000000000000000000000000000000000..be936274186f4fb2ed2f6cb47ae6d1c99f2271c1 --- /dev/null +++ b/paddle/fluid/framework/new_executor/thread_environment.h @@ -0,0 +1,42 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#pragma once + +#include +#include + +namespace paddle { +namespace framework { + +struct StlThreadEnvironment { + struct Task { + std::function f; + }; + + // EnvThread constructor must start the thread, + // destructor must join the thread. + class EnvThread { + public: + explicit EnvThread(std::function f) : thr_(std::move(f)) {} + ~EnvThread() { thr_.join(); } + + private: + std::thread thr_; + }; + + EnvThread* CreateThread(std::function f) { + return new EnvThread(std::move(f)); + } + Task CreateTask(std::function f) { return Task{std::move(f)}; } + void ExecuteTask(const Task& t) { t.f(); } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue.cc new file mode 100644 index 0000000000000000000000000000000000000000..6586b418ba2781c5cdc46a5e0a221662e076557e --- /dev/null +++ b/paddle/fluid/framework/new_executor/workqueue.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "paddle/fluid/framework/new_executor/workqueue.h" +#include "paddle/fluid/framework/new_executor/nonblocking_threadpool.h" + +namespace paddle { +namespace framework { + +class SingleThreadedWorkQueue : public WorkQueue { + public: + SingleThreadedWorkQueue() : queue_(1) {} + + SingleThreadedWorkQueue(const SingleThreadedWorkQueue&) = delete; + + SingleThreadedWorkQueue& operator=(const SingleThreadedWorkQueue&) = delete; + + virtual ~SingleThreadedWorkQueue() = default; + + void AddTask(std::function fn) override { + queue_.AddTask(std::move(fn)); + } + + void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); } + + size_t NumThreads() override { return queue_.NumThreads(); } + + private: + NonblockingThreadPool queue_; +}; + +std::unique_ptr CreateSingleThreadedWorkQueue() { + std::unique_ptr ptr(new SingleThreadedWorkQueue); + return std::move(ptr); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/workqueue.h b/paddle/fluid/framework/new_executor/workqueue.h new file mode 100644 index 0000000000000000000000000000000000000000..402bd215d63167ba2e25c72efe9efebbc7b7f862 --- /dev/null +++ b/paddle/fluid/framework/new_executor/workqueue.h @@ -0,0 +1,45 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace paddle { +namespace framework { + +class WorkQueue { + public: + WorkQueue() = default; + + WorkQueue(const WorkQueue&) = delete; + + WorkQueue& operator=(const WorkQueue&) = delete; + + virtual ~WorkQueue() = default; + + virtual void AddTask(std::function fn) = 0; + + virtual void WaitQueueEmpty() = 0; + + virtual size_t NumThreads() = 0; +}; + +std::unique_ptr CreateSingleThreadedWorkQueue(); + +std::unique_ptr CreateMultiThreadedWorkQueue(int num_threads); + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5d0a7f4033ad5b39311e15fb63d76614240dc76 --- /dev/null +++ b/paddle/fluid/framework/new_executor/workqueue_test.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/workqueue.h" +#include +#include "glog/logging.h" +#include "gtest/gtest.h" + +TEST(WorkQueue, TestSingleThreadedWorkQueue) { + VLOG(1) << "In Test"; + using paddle::framework::WorkQueue; + using paddle::framework::CreateSingleThreadedWorkQueue; + std::atomic finished{false}; + std::atomic counter{0}; + constexpr unsigned kLoopNum = 1000000; + // CreateSingleThreadedWorkQueue + std::unique_ptr work_queue = CreateSingleThreadedWorkQueue(); + // NumThreads + EXPECT_EQ(work_queue->NumThreads(), 1u); + // AddTask + EXPECT_EQ(finished.load(), false); + EXPECT_EQ(counter.load(), 0u); + work_queue->AddTask([&counter, &finished, kLoopNum]() { + for (unsigned i = 0; i < kLoopNum; ++i) { + ++counter; + } + finished = true; + }); + // WaitQueueEmpty + EXPECT_EQ(finished.load(), false); + work_queue->WaitQueueEmpty(); + EXPECT_EQ(finished.load(), true); + EXPECT_EQ(counter.load(), kLoopNum); +}