未验证 提交 cdb9bfa3 编写于 作者: L liutiexing 提交者: GitHub

[new-exec] Add events waiter (#36480)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* update

* update

* update Error MSG

* update EventsWaiter
上级 eff3ee5e
......@@ -50,11 +50,13 @@
#include <cstdlib>
#include <mutex>
#include <vector>
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
namespace paddle {
namespace framework {
void* AlignedMalloc(size_t size, size_t alignment);
void AlignedFree(void* memory_ptr);
class EventCount {
public:
class Waiter;
......
......@@ -37,7 +37,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
main_program_(main_prog),
global_scope_(global_scope),
stream_analyzer_(place),
async_work_queue_(kHostNumThreads) {
async_work_queue_(kHostNumThreads, &main_thread_blocker_) {
is_build_ = false;
feed_names_ = feed_names;
......@@ -367,7 +367,8 @@ void InterpreterCore::ExecuteInstructionList(
}
}
async_work_queue_.WaitEmpty();
auto event_id = main_thread_blocker_.WaitEvent();
VLOG(3) << "event_id " << event_id;
PADDLE_ENFORCE_EQ(
op_run_number_.load(), vec_instr.size(),
......
......@@ -95,6 +95,7 @@ class InterpreterCore {
InterpreterProfiler dry_run_profiler_;
StreamAnalyzer stream_analyzer_;
EventManager event_manager_;
EventsWaiter main_thread_blocker_;
interpretercore::AsyncWorkQueue async_work_queue_;
InterpreterCoreGarbageCollector gc_;
......
......@@ -33,6 +33,7 @@
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/workqueue.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
......@@ -53,16 +54,19 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
class AsyncWorkQueue {
public:
explicit AsyncWorkQueue(size_t host_num_threads)
AsyncWorkQueue(size_t host_num_threads, EventsWaiter* waiter)
: host_num_thread_(host_num_threads) {
std::vector<WorkQueueOptions> group_options;
// for execute host Kernel
group_options.emplace_back(/*num_threads*/ host_num_threads,
/*allow_spinning*/ true,
/*track_task*/ true);
/*track_task*/ true,
/*queue_empty_waiter*/ waiter);
// for launch device Kernel
group_options.emplace_back(/*num_threads*/ 1,
/*allow_spinning*/ true, /*track_task*/ true);
/*allow_spinning*/ true,
/*track_task*/ true,
/*queue_empty_waiter*/ waiter);
queue_group_ = CreateWorkQueueGroup(group_options);
}
......@@ -71,7 +75,7 @@ class AsyncWorkQueue {
AtomicVectorSizeT& PrepareAtomicVarRef(
const std::vector<VariableMetaInfo>& vec_meta_info);
void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); }
// void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); }
void AddTask(const OpFuncType& op_func_type, std::function<void()> fn) {
queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
......
......@@ -19,9 +19,12 @@
namespace paddle {
namespace framework {
template <typename Notifier>
class TaskTracker {
public:
TaskTracker() : wait_empty_cv_(1) {}
TaskTracker() = default;
explicit TaskTracker(Notifier& notifier) : notifier_(&notifier) {}
TaskTracker(const TaskTracker&) = delete;
......@@ -33,32 +36,17 @@ class TaskTracker {
void SubCounter() {
if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {
wait_empty_cv_.Notify(true);
if (notifier_ != nullptr) {
notifier_->NotifyEvent();
}
}
}
// only one user can wait at any time
void WaitTaskNumToZero() {
bool waiting = false;
if (!wait_empty_.compare_exchange_strong(waiting, true,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
abort();
}
EventCount::Waiter* w = wait_empty_cv_.GetWaiter(0);
wait_empty_cv_.Prewait();
if (num_tasks_.load(std::memory_order_relaxed) == 0) {
wait_empty_cv_.CancelWait();
} else {
wait_empty_cv_.CommitWait(w);
}
wait_empty_.store(false);
}
uint64_t PendingTaskNum() { return num_tasks_.load(); }
private:
alignas(64) std::atomic<uint64_t> num_tasks_{0};
alignas(64) EventCount wait_empty_cv_;
alignas(64) std::atomic<bool> wait_empty_{false};
Notifier* notifier_{nullptr};
};
template <typename Environment>
......
......@@ -13,13 +13,18 @@ namespace paddle {
namespace framework {
namespace {
using TaskTracker = TaskTracker<EventsWaiter::EventNotifier>;
class WorkQueueImpl : public WorkQueue {
public:
explicit WorkQueueImpl(const WorkQueueOptions& options)
: WorkQueue(options), queue_(nullptr), tracker_(nullptr) {
if (options_.track_task) {
explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) {
if (options_.track_task && options.queue_empty_waiter != nullptr) {
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
tracker_ = new (storage) TaskTracker;
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
auto notifier = options.queue_empty_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*notifier.get());
}
queue_ = new NonblockingThreadPool(options_.num_threads,
options_.allow_spinning);
......@@ -44,20 +49,11 @@ class WorkQueueImpl : public WorkQueue {
queue_->AddTask(std::move(fn));
}
void WaitQueueEmpty() override {
if (tracker_ == nullptr) {
PADDLE_THROW(
platform::errors::Unavailable("set WorkQueueOptions.track_task = "
"true before call this interface."));
}
tracker_->WaitTaskNumToZero();
}
size_t NumThreads() const override { return queue_->NumThreads(); }
private:
NonblockingThreadPool* queue_;
TaskTracker* tracker_;
NonblockingThreadPool* queue_{nullptr};
TaskTracker* tracker_{nullptr};
};
class WorkQueueGroupImpl : public WorkQueueGroup {
......@@ -69,8 +65,6 @@ class WorkQueueGroupImpl : public WorkQueueGroup {
void AddTask(size_t queue_idx, std::function<void()> fn) override;
void WaitQueueGroupEmpty() override;
size_t QueueNumThreads(size_t queue_idx) const override;
size_t QueueGroupNumThreads() const override;
......@@ -92,9 +86,14 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
queues_storage_ = reinterpret_cast<NonblockingThreadPool*>(buffer);
for (size_t idx = 0; idx < num_queues; ++idx) {
const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr) {
if (options.track_task && tracker_ == nullptr &&
options.queue_empty_waiter != nullptr) {
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
tracker_ = new (storage) TaskTracker;
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
auto notifier = options.queue_empty_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*notifier.get());
}
queues_[idx] = new (&queues_storage_[idx])
NonblockingThreadPool(options.num_threads, options.allow_spinning);
......@@ -124,15 +123,6 @@ void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) {
queues_[queue_idx]->AddTask(std::move(fn));
}
void WorkQueueGroupImpl::WaitQueueGroupEmpty() {
if (nullptr == tracker_) {
PADDLE_THROW(platform::errors::Unavailable(
"set WorkQueueOptions.track_task = true for at least one of queues "
"before call this interface."));
}
tracker_->WaitTaskNumToZero();
}
size_t WorkQueueGroupImpl::QueueNumThreads(size_t queue_idx) const {
assert(queue_idx < queues_.size());
return queues_.at(queue_idx)->NumThreads();
......
......@@ -21,15 +21,30 @@
namespace paddle {
namespace framework {
constexpr const char* kQueueEmptyEvent = "QueueEmpty";
class EventsWaiter;
struct WorkQueueOptions {
WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task)
: num_threads(num_threads),
allow_spinning(allow_spinning),
track_task(track_task) {}
WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task,
EventsWaiter* waiter)
: num_threads(num_threads),
allow_spinning(allow_spinning),
track_task(track_task),
queue_empty_waiter(waiter) {}
size_t num_threads;
bool allow_spinning;
// If you need to blocking the calling thread to wait "queue empty", set
// track_task = true and set queue_empty_waiter. EventsWaiter::WaitEvent will
// block the calling thread until any of events (including "queue empty")
// occured.
bool track_task;
EventsWaiter* queue_empty_waiter{nullptr}; // not owned
};
class WorkQueue {
......@@ -44,9 +59,8 @@ class WorkQueue {
virtual void AddTask(std::function<void()> fn) = 0;
// set WorkQueueOptions.track_task = true before call this
// interface, otherwise will abort()
virtual void WaitQueueEmpty() = 0;
// See WorkQueueOptions.track_task for details
// virtual void WaitQueueEmpty() = 0;
virtual size_t NumThreads() const = 0;
......@@ -67,9 +81,8 @@ class WorkQueueGroup {
virtual void AddTask(size_t queue_idx, std::function<void()> fn) = 0;
// set WorkQueueOptions.track_task = true for at least one of queues
// before call this interface, otherwise will abort()
virtual void WaitQueueGroupEmpty() = 0;
// See WorkQueueOptions.track_task for details
// virtual void WaitQueueGroupEmpty() = 0;
virtual size_t QueueNumThreads(size_t queue_idx) const = 0;
......
......@@ -16,18 +16,21 @@
#include <atomic>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
TEST(WorkQueue, TestSingleThreadedWorkQueue) {
VLOG(1) << "In Test";
using paddle::framework::WorkQueueOptions;
using paddle::framework::WorkQueue;
using paddle::framework::CreateSingleThreadedWorkQueue;
using paddle::framework::EventsWaiter;
std::atomic<bool> finished{false};
std::atomic<unsigned> counter{0};
constexpr unsigned kLoopNum = 1000000;
// CreateSingleThreadedWorkQueue
EventsWaiter events_waiter;
WorkQueueOptions options(/*num_threads*/ 1, /*allow_spinning*/ true,
/*track_task*/ true);
/*track_task*/ true, &events_waiter);
auto work_queue = CreateSingleThreadedWorkQueue(options);
// NumThreads
EXPECT_EQ(work_queue->NumThreads(), 1u);
......@@ -42,7 +45,7 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
});
// WaitQueueEmpty
EXPECT_EQ(finished.load(), false);
work_queue->WaitQueueEmpty();
events_waiter.WaitEvent();
EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum);
}
......@@ -52,13 +55,15 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
using paddle::framework::WorkQueueOptions;
using paddle::framework::WorkQueue;
using paddle::framework::CreateMultiThreadedWorkQueue;
using paddle::framework::EventsWaiter;
std::atomic<bool> finished{false};
std::atomic<unsigned> counter{0};
constexpr unsigned kExternalLoopNum = 100;
constexpr unsigned kLoopNum = 1000000;
// CreateMultiThreadedWorkQueue
EventsWaiter events_waiter;
WorkQueueOptions options(/*num_threads*/ 10, /*allow_spinning*/ true,
/*track_task*/ true);
/*track_task*/ true, &events_waiter);
auto work_queue = CreateMultiThreadedWorkQueue(options);
// NumThreads
EXPECT_EQ(work_queue->NumThreads(), 10u);
......@@ -75,7 +80,7 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
}
// WaitQueueEmpty
EXPECT_EQ(finished.load(), false);
work_queue->WaitQueueEmpty();
events_waiter.WaitEvent();
EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum);
}
......@@ -84,15 +89,17 @@ TEST(WorkQueue, TestWorkQueueGroup) {
using paddle::framework::WorkQueueOptions;
using paddle::framework::WorkQueueGroup;
using paddle::framework::CreateWorkQueueGroup;
using paddle::framework::EventsWaiter;
std::atomic<bool> finished{false};
std::atomic<unsigned> counter{0};
constexpr unsigned kExternalLoopNum = 100;
constexpr unsigned kLoopNum = 1000000;
// CreateMultiThreadedWorkQueue
// ThreadedWorkQueueGroup
EventsWaiter events_waiter;
WorkQueueOptions sq_options(/*num_threads*/ 1, /*allow_spinning*/ true,
/*track_task*/ true);
/*track_task*/ true, &events_waiter);
WorkQueueOptions mq_options(/*num_threads*/ 10, /*allow_spinning*/ true,
/*track_task*/ true);
/*track_task*/ true, &events_waiter);
auto queue_group = CreateWorkQueueGroup({sq_options, mq_options});
// NumThreads
EXPECT_EQ(queue_group->QueueNumThreads(0), 1u);
......@@ -113,6 +120,6 @@ TEST(WorkQueue, TestWorkQueueGroup) {
}
});
// WaitQueueGroupEmpty()
queue_group->WaitQueueGroupEmpty();
events_waiter.WaitEvent();
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
}
......@@ -55,5 +55,62 @@ void AlignedFree(void* mem_ptr) {
#endif
}
constexpr EventsWaiter::EventId kEmptyEventId = -1;
EventsWaiter::EventsWaiter()
: trigger_event_(kEmptyEventId), waiting_(false), cv_(1) {}
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name, EventChecker checker) {
names_.emplace_back(name);
checkers_.emplace_back(std::move(checker));
EventId id = checkers_.size() - 1;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
notifiers_.emplace_back(notifier);
return notifier;
}
std::string EventsWaiter::WaitEvent() {
// only one user can wait at any time
bool waiting = false;
if (!waiting_.compare_exchange_strong(waiting, true,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
PADDLE_THROW(
platform::errors::ResourceExhausted("Another thread is waiting."));
}
EventId id = kEmptyEventId;
auto w = cv_.GetWaiter(0);
cv_.Prewait();
int64_t event_num = checkers_.size();
for (int64_t i = 0; id == kEmptyEventId && i < event_num; ++i) {
if (checkers_[i]()) {
id = i;
}
}
if (id != kEmptyEventId) {
cv_.CancelWait();
} else {
cv_.CommitWait(w);
id = trigger_event_.load(std::memory_order_relaxed);
}
trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false);
return names_.at(id);
}
void EventsWaiter::SetTriggerEvent(const EventId& id) {
trigger_event_.store(id, std::memory_order_relaxed);
cv_.Notify(true);
}
std::string EventsWaiter::EventNotifier::GetEventName() {
return waiter_.names_.at(id_);
}
void EventsWaiter::EventNotifier::NotifyEvent() {
waiter_.SetTriggerEvent(id_);
}
} // namespace framework
} // namespace paddle
......@@ -18,6 +18,11 @@
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/new_executor/event_count.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -64,5 +69,56 @@ void* AlignedMalloc(size_t size, size_t alignment);
void AlignedFree(void* memory_ptr);
// A multiplexing waiter, be able to wait multi events simultaneously.
// Blocking the calling thread to wait any of the registered events.
// Non-thread-safe.
class EventsWaiter {
public:
using EventId = int64_t;
using EventChecker = std::function<bool()>;
class EventNotifier {
public:
void NotifyEvent();
EventId GetEventId() { return id_; }
std::string GetEventName();
private:
friend EventsWaiter;
EventNotifier(EventId id, EventsWaiter* waiter)
: id_(id), waiter_(*waiter) {}
EventId id_;
EventsWaiter& waiter_;
};
EventsWaiter();
EventsWaiter(const EventsWaiter&) = delete;
EventsWaiter& operator=(const EventsWaiter&) = delete;
// All the RegisterEvent functions must be called before any WaitEvent
std::shared_ptr<EventNotifier> RegisterEvent(const std::string& name,
EventChecker checker);
// Wait any of the registered events
std::string WaitEvent();
private:
friend EventNotifier;
void SetTriggerEvent(const EventId& id);
std::vector<std::string> names_;
std::vector<EventChecker> checkers_;
std::vector<std::shared_ptr<EventNotifier>> notifiers_;
std::atomic<EventId> trigger_event_;
std::atomic<bool> waiting_;
EventCount cv_;
};
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册