Unverified commit 4e7bd9c3, authored by liutiexing, committed by GitHub

Optimize workqueue (#35931)

* add align for WorkQueue

* WorkQueue update

* Revert "WorkQueue update"

This reverts commit 14ce793dbb204f8ddec63c34b3b72a73c7cdb93a.

* optimize WorkQueue
Parent d6d2dafa
......@@ -2,7 +2,7 @@ set(INTERPRETERCORE_DEPS op_registry device_context scope framework_proto data_f
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor)
cc_library(workqueue SRCS workqueue.cc DEPS enforce)
cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc DEPS enforce)
cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS})
cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue)
cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog)
......
......@@ -50,6 +50,7 @@
#include <cstdlib>
#include <mutex>
#include <vector>
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
namespace paddle {
namespace framework {
......@@ -60,7 +61,7 @@ class EventCount {
explicit EventCount(size_t waiter_num) : state_(kStackMask) {
assert(waiter_num < (1 << kWaiterBits) - 1);
void* buffer = malloc(sizeof(Waiter) * waiter_num);
void* buffer = AlignedMalloc(sizeof(Waiter) * waiter_num, alignof(Waiter));
if (buffer == nullptr) {
return;
}
......@@ -78,7 +79,7 @@ class EventCount {
~EventCount() {
// Ensure there are no waiters.
assert(state_.load() == kStackMask);
free(waiters_);
AlignedFree(waiters_);
}
Waiter* GetWaiter(size_t waiter_index) {
......
......@@ -56,9 +56,9 @@ class TaskTracker {
}
private:
std::atomic<uint64_t> num_tasks_{0};
EventCount wait_empty_cv_;
std::atomic<bool> wait_empty_{false};
alignas(64) std::atomic<uint64_t> num_tasks_{0};
alignas(64) EventCount wait_empty_cv_;
alignas(64) std::atomic<bool> wait_empty_{false};
};
template <typename Environment>
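
The hunk above pads each hot `TaskTracker` member to a 64-byte boundary, so `num_tasks_`, `wait_empty_cv_`, and `wait_empty_` land on separate cache lines and concurrent updates from different worker threads do not false-share. A minimal standalone sketch of the same idea (not the Paddle class; 64 bytes is assumed to be the cache-line size):

```cpp
#include <atomic>
#include <cstdint>

// Each hot counter gets its own 64-byte cache line, so a thread bumping
// `submitted` does not invalidate the cache line holding `completed`.
struct PaddedCounters {
  alignas(64) std::atomic<std::uint64_t> submitted{0};  // bumped by producers
  alignas(64) std::atomic<std::uint64_t> completed{0};  // bumped by workers
};

static_assert(alignof(PaddedCounters) == 64,
              "members start on separate cache lines");
```
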
......@@ -70,15 +70,16 @@ class ThreadPoolTempl {
ThreadPoolTempl(int num_threads, bool allow_spinning,
Environment env = Environment())
: env_(env),
num_threads_(num_threads),
allow_spinning_(allow_spinning),
thread_data_(num_threads),
global_steal_partition_(EncodePartition(0, num_threads_)),
blocked_(0),
num_tasks_(0),
spinning_(0),
done_(false),
cancelled_(false),
ec_(num_threads_) {
ec_(num_threads),
num_threads_(num_threads),
thread_data_(num_threads) {
// Calculate coprimes of all numbers [1, num_threads].
// Coprimes are used for random walks over all threads in Steal
// and NonEmptyQueueIndex. Iteration is based on the fact that if we take
......@@ -143,6 +144,7 @@ class ThreadPoolTempl {
void AddTaskWithHint(std::function<void()> fn, int start, int limit) {
Task t = env_.CreateTask(std::move(fn));
PerThread* pt = GetPerThread();
uint64_t num_tasks = num_tasks_.fetch_add(1, std::memory_order_relaxed) + 1;
if (pt->pool == this) {
// Worker thread of this pool, push onto the thread's queue.
Queue& q = thread_data_[pt->thread_id].queue;
......@@ -166,8 +168,11 @@ class ThreadPoolTempl {
// this. We expect that such scenario is prevented by program, that is,
// this is kept alive while any threads can potentially be in Schedule.
if (!t.f) {
ec_.Notify(false);
if (num_tasks > num_threads_ - blocked_.load(std::memory_order_relaxed)) {
ec_.Notify(false);
}
} else {
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
env_.ExecuteTask(t); // Push failed, execute directly.
}
}
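
The `AddTaskWithHint` change keeps a relaxed pending-task counter and only calls `ec_.Notify(false)` when the backlog exceeds the number of workers that are not blocked, so a sleeping thread is not woken while spinning workers can already pick the task up. A rough, self-contained sketch of that gating logic (hypothetical names; `std::condition_variable` stands in for `EventCount`, and the counts may be slightly stale, which only costs a spurious or delayed wakeup):

```cpp
#include <atomic>
#include <condition_variable>
#include <cstdint>

class WakeupGate {
 public:
  explicit WakeupGate(int num_threads) : num_threads_(num_threads) {}

  // Called on task submission; mirrors the pattern in the diff above.
  void OnTaskSubmitted() {
    uint64_t pending = pending_.fetch_add(1, std::memory_order_relaxed) + 1;
    // blocked_ counts workers parked in a blocking wait (they would increment
    // it before sleeping and decrement it after waking; not shown here).
    int awake = num_threads_ - blocked_.load(std::memory_order_relaxed);
    if (awake < 0) awake = 0;
    if (pending > static_cast<uint64_t>(awake)) {
      cv_.notify_one();  // stand-in for EventCount::Notify(false)
    }
  }

  // Called after a worker finishes executing a task.
  void OnTaskExecuted() { pending_.fetch_sub(1, std::memory_order_relaxed); }

 private:
  const int num_threads_;
  std::atomic<uint64_t> pending_{0};
  std::atomic<int> blocked_{0};
  std::condition_variable cv_;
};
```
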
......@@ -259,16 +264,17 @@ class ThreadPoolTempl {
};
Environment env_;
const int num_threads_;
const bool allow_spinning_;
std::vector<ThreadData> thread_data_;
std::vector<std::vector<unsigned>> all_coprimes_;
unsigned global_steal_partition_;
std::atomic<unsigned> blocked_;
std::atomic<uint64_t> num_tasks_;
std::atomic<bool> spinning_;
std::atomic<bool> done_;
std::atomic<bool> cancelled_;
EventCount ec_;
const int num_threads_;
std::vector<ThreadData> thread_data_;
// Main worker thread loop.
void WorkerLoop(int thread_id) {
......@@ -305,6 +311,7 @@ class ThreadPoolTempl {
}
if (t.f) {
env_.ExecuteTask(t);
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
}
}
} else {
......@@ -315,8 +322,7 @@ class ThreadPoolTempl {
if (!t.f) {
t = GlobalSteal();
if (!t.f) {
// Leave one thread spinning. This reduces latency.
if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
if (allow_spinning_) {
for (int i = 0; i < spin_count && !t.f; i++) {
if (!cancelled_.load(std::memory_order_relaxed)) {
t = GlobalSteal();
......@@ -324,7 +330,6 @@ class ThreadPoolTempl {
return;
}
}
spinning_ = false;
}
if (!t.f) {
if (!WaitForWork(waiter, &t)) {
......@@ -336,6 +341,7 @@ class ThreadPoolTempl {
}
if (t.f) {
env_.ExecuteTask(t);
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
}
}
}
......
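
In the worker loop, the diff drops the single-spinner guard (`spinning_.exchange(true)`): any idle worker may now spin for `spin_count` steal attempts when `allow_spinning_` is set, and the executed-task path decrements `num_tasks_` to keep the counter used by the notify gate accurate. A compact sketch of the resulting "spin, then block" idle path; `Task`, `TrySteal`, and `WaitForWork` are assumed helpers, not the pool's real API:

```cpp
#include <atomic>
#include <cstdint>
#include <functional>
#include <optional>

using Task = std::function<void()>;

std::optional<Task> TrySteal();              // assumed: non-blocking poll
bool WaitForWork(std::optional<Task>* out);  // assumed: blocking wait

void IdlePath(bool allow_spinning, int spin_count,
              const std::atomic<bool>& cancelled,
              std::atomic<uint64_t>& num_tasks) {
  std::optional<Task> t = TrySteal();
  if (!t && allow_spinning) {
    for (int i = 0; i < spin_count && !t; ++i) {
      if (cancelled.load(std::memory_order_relaxed)) return;
      t = TrySteal();  // keep polling; cheaper than a sleep/wake round trip
    }
  }
  if (!t && !WaitForWork(&t)) return;  // park the thread until notified
  if (t) {
    (*t)();  // run the task
    num_tasks.fetch_sub(1, std::memory_order_relaxed);
  }
}
```
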
......@@ -204,7 +204,6 @@ class RunQueue {
kReady,
};
std::mutex mutex_;
// Low log(kSize) + 1 bits in front_ and back_ contain rolling index of
// front/back, respectively. The remaining bits contain modification counters
// that are incremented on Push operations. This allows us to (1) distinguish
......@@ -214,6 +213,7 @@ class RunQueue {
// modification counters.
alignas(64) std::atomic<unsigned> front_;
alignas(64) std::atomic<unsigned> back_;
std::mutex mutex_;
Elem array_[kSize];
// SizeOrNotEmpty returns current queue size; if NeedSizeEstimate is false,
......
......@@ -18,14 +18,18 @@ class WorkQueueImpl : public WorkQueue {
explicit WorkQueueImpl(const WorkQueueOptions& options)
: WorkQueue(options), queue_(nullptr), tracker_(nullptr) {
if (options_.track_task) {
tracker_ = new TaskTracker;
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
tracker_ = new (storage) TaskTracker;
}
queue_ = new NonblockingThreadPool(options_.num_threads,
options_.allow_spinning);
}
virtual ~WorkQueueImpl() {
delete tracker_;
if (tracker_ != nullptr) {
tracker_->~TaskTracker();
AlignedFree(tracker_);
}
delete queue_;
}
......@@ -89,7 +93,8 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
for (size_t idx = 0; idx < num_queues; ++idx) {
const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr) {
tracker_ = new TaskTracker;
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
tracker_ = new (storage) TaskTracker;
}
queues_[idx] = new (&queues_storage_[idx])
NonblockingThreadPool(options.num_threads, options.allow_spinning);
......@@ -100,7 +105,10 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() {
for (auto queue : queues_) {
queue->~NonblockingThreadPool();
}
delete tracker_;
if (tracker_ != nullptr) {
tracker_->~TaskTracker();
AlignedFree(tracker_);
}
free(queues_storage_);
}
......
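
Because `TaskTracker` now has `alignas(64)` members, its alignment exceeds what a plain `new` expression guarantees before C++17's aligned `new`; the diff therefore allocates storage with `AlignedMalloc`, constructs with placement new, and tears the object down with an explicit destructor call followed by `AlignedFree`. A small generic wrapper capturing the same pattern (hypothetical helper names, reusing the `AlignedMalloc`/`AlignedFree` declared in `workqueue_utils.h`):

```cpp
#include <new>      // placement new
#include <utility>  // std::forward
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"

// Allocate over-aligned storage and construct T in place.
template <typename T, typename... Args>
T* NewAligned(Args&&... args) {
  void* storage = paddle::framework::AlignedMalloc(sizeof(T), alignof(T));
  if (storage == nullptr) return nullptr;
  return new (storage) T(std::forward<Args>(args)...);
}

// Destroy a T created by NewAligned: run the destructor, then free the storage.
template <typename T>
void DeleteAligned(T* obj) {
  if (obj != nullptr) {
    obj->~T();                            // explicit destructor call
    paddle::framework::AlignedFree(obj);  // release the raw storage
  }
}
```
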
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
#include <cstdint>
#include <cstdlib>
namespace paddle {
namespace framework {
void* AlignedMalloc(size_t size, size_t alignment) {
assert(alignment >= sizeof(void*) && (alignment & (alignment - 1)) == 0);
size = (size + alignment - 1) / alignment * alignment;
#if defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 200112L
void* aligned_mem = nullptr;
if (posix_memalign(&aligned_mem, alignment, size) != 0) {
aligned_mem = nullptr;
}
return aligned_mem;
#elif defined(_WIN32)
return _aligned_malloc(size, alignment);
#else
void* mem = malloc(size + alignment);
if (mem == nullptr) {
return nullptr;
}
size_t adjust = alignment - reinterpret_cast<uint64_t>(mem) % alignment;
void* aligned_mem = reinterpret_cast<char*>(mem) + adjust;
*(reinterpret_cast<void**>(aligned_mem) - 1) = mem;
assert(reinterpret_cast<uint64_t>(aligned_mem) % alignment == 0);
return aligned_mem;
#endif
}
void AlignedFree(void* mem_ptr) {
#if defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 200112L
free(mem_ptr);
#elif defined(_WIN32)
_aligned_free(mem_ptr);
#else
if (mem_ptr) {
free(*(reinterpret_cast<void**>(mem_ptr) - 1));
}
#endif
}
} // namespace framework
} // namespace paddle
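
On platforms without `posix_memalign` or `_aligned_malloc`, the fallback branch above over-allocates by `alignment` bytes and stashes the original `malloc` pointer in the slot just before the aligned address, so `AlignedFree` can recover and release it. A minimal usage sketch (assuming the Paddle include paths are available):

```cpp
#include <cassert>
#include <cstdint>
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"

int main() {
  using paddle::framework::AlignedFree;
  using paddle::framework::AlignedMalloc;

  // Request 100 bytes aligned to a 64-byte cache line.
  void* p = AlignedMalloc(100, 64);
  assert(p != nullptr);
  assert(reinterpret_cast<std::uintptr_t>(p) % 64 == 0);

  AlignedFree(p);  // must pair with AlignedMalloc, not plain free()
  return 0;
}
```
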
......@@ -59,5 +59,9 @@ class CounterGuard {
Holder* counter_holder_{nullptr};
};
void* AlignedMalloc(size_t size, size_t alignment);
void AlignedFree(void* memory_ptr);
} // namespace framework
} // namespace paddle