未验证 提交 cdb9bfa3 编写于 作者: L liutiexing 提交者: GitHub

[new-exec] Add events waiter (#36480)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* update

* update

* update Error MSG

* update EventsWaiter
上级 eff3ee5e
...@@ -50,11 +50,13 @@ ...@@ -50,11 +50,13 @@
#include <cstdlib> #include <cstdlib>
#include <mutex> #include <mutex>
#include <vector> #include <vector>
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void* AlignedMalloc(size_t size, size_t alignment);
void AlignedFree(void* memory_ptr);
class EventCount { class EventCount {
public: public:
class Waiter; class Waiter;
......
...@@ -37,7 +37,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -37,7 +37,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
main_program_(main_prog), main_program_(main_prog),
global_scope_(global_scope), global_scope_(global_scope),
stream_analyzer_(place), stream_analyzer_(place),
async_work_queue_(kHostNumThreads) { async_work_queue_(kHostNumThreads, &main_thread_blocker_) {
is_build_ = false; is_build_ = false;
feed_names_ = feed_names; feed_names_ = feed_names;
...@@ -367,7 +367,8 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -367,7 +367,8 @@ void InterpreterCore::ExecuteInstructionList(
} }
} }
async_work_queue_.WaitEmpty(); auto event_id = main_thread_blocker_.WaitEvent();
VLOG(3) << "event_id " << event_id;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_run_number_.load(), vec_instr.size(), op_run_number_.load(), vec_instr.size(),
......
...@@ -95,6 +95,7 @@ class InterpreterCore { ...@@ -95,6 +95,7 @@ class InterpreterCore {
InterpreterProfiler dry_run_profiler_; InterpreterProfiler dry_run_profiler_;
StreamAnalyzer stream_analyzer_; StreamAnalyzer stream_analyzer_;
EventManager event_manager_; EventManager event_manager_;
EventsWaiter main_thread_blocker_;
interpretercore::AsyncWorkQueue async_work_queue_; interpretercore::AsyncWorkQueue async_work_queue_;
InterpreterCoreGarbageCollector gc_; InterpreterCoreGarbageCollector gc_;
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/workqueue.h" #include "paddle/fluid/framework/new_executor/workqueue.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -53,16 +54,19 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>; ...@@ -53,16 +54,19 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
class AsyncWorkQueue { class AsyncWorkQueue {
public: public:
explicit AsyncWorkQueue(size_t host_num_threads) AsyncWorkQueue(size_t host_num_threads, EventsWaiter* waiter)
: host_num_thread_(host_num_threads) { : host_num_thread_(host_num_threads) {
std::vector<WorkQueueOptions> group_options; std::vector<WorkQueueOptions> group_options;
// for execute host Kernel // for execute host Kernel
group_options.emplace_back(/*num_threads*/ host_num_threads, group_options.emplace_back(/*num_threads*/ host_num_threads,
/*allow_spinning*/ true, /*allow_spinning*/ true,
/*track_task*/ true); /*track_task*/ true,
/*queue_empty_waiter*/ waiter);
// for launch device Kernel // for launch device Kernel
group_options.emplace_back(/*num_threads*/ 1, group_options.emplace_back(/*num_threads*/ 1,
/*allow_spinning*/ true, /*track_task*/ true); /*allow_spinning*/ true,
/*track_task*/ true,
/*queue_empty_waiter*/ waiter);
queue_group_ = CreateWorkQueueGroup(group_options); queue_group_ = CreateWorkQueueGroup(group_options);
} }
...@@ -71,7 +75,7 @@ class AsyncWorkQueue { ...@@ -71,7 +75,7 @@ class AsyncWorkQueue {
AtomicVectorSizeT& PrepareAtomicVarRef( AtomicVectorSizeT& PrepareAtomicVarRef(
const std::vector<VariableMetaInfo>& vec_meta_info); const std::vector<VariableMetaInfo>& vec_meta_info);
void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); } // void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); }
void AddTask(const OpFuncType& op_func_type, std::function<void()> fn) { void AddTask(const OpFuncType& op_func_type, std::function<void()> fn) {
queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn)); queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
......
...@@ -19,9 +19,12 @@ ...@@ -19,9 +19,12 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename Notifier>
class TaskTracker { class TaskTracker {
public: public:
TaskTracker() : wait_empty_cv_(1) {} TaskTracker() = default;
explicit TaskTracker(Notifier& notifier) : notifier_(&notifier) {}
TaskTracker(const TaskTracker&) = delete; TaskTracker(const TaskTracker&) = delete;
...@@ -33,32 +36,17 @@ class TaskTracker { ...@@ -33,32 +36,17 @@ class TaskTracker {
void SubCounter() { void SubCounter() {
if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) { if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {
wait_empty_cv_.Notify(true); if (notifier_ != nullptr) {
notifier_->NotifyEvent();
}
} }
} }
// only one user can wait at any time uint64_t PendingTaskNum() { return num_tasks_.load(); }
void WaitTaskNumToZero() {
bool waiting = false;
if (!wait_empty_.compare_exchange_strong(waiting, true,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
abort();
}
EventCount::Waiter* w = wait_empty_cv_.GetWaiter(0);
wait_empty_cv_.Prewait();
if (num_tasks_.load(std::memory_order_relaxed) == 0) {
wait_empty_cv_.CancelWait();
} else {
wait_empty_cv_.CommitWait(w);
}
wait_empty_.store(false);
}
private: private:
alignas(64) std::atomic<uint64_t> num_tasks_{0}; alignas(64) std::atomic<uint64_t> num_tasks_{0};
alignas(64) EventCount wait_empty_cv_; Notifier* notifier_{nullptr};
alignas(64) std::atomic<bool> wait_empty_{false};
}; };
template <typename Environment> template <typename Environment>
......
...@@ -13,13 +13,18 @@ namespace paddle { ...@@ -13,13 +13,18 @@ namespace paddle {
namespace framework { namespace framework {
namespace { namespace {
using TaskTracker = TaskTracker<EventsWaiter::EventNotifier>;
class WorkQueueImpl : public WorkQueue { class WorkQueueImpl : public WorkQueue {
public: public:
explicit WorkQueueImpl(const WorkQueueOptions& options) explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) {
: WorkQueue(options), queue_(nullptr), tracker_(nullptr) { if (options_.track_task && options.queue_empty_waiter != nullptr) {
if (options_.track_task) {
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
tracker_ = new (storage) TaskTracker; TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
auto notifier = options.queue_empty_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*notifier.get());
} }
queue_ = new NonblockingThreadPool(options_.num_threads, queue_ = new NonblockingThreadPool(options_.num_threads,
options_.allow_spinning); options_.allow_spinning);
...@@ -44,20 +49,11 @@ class WorkQueueImpl : public WorkQueue { ...@@ -44,20 +49,11 @@ class WorkQueueImpl : public WorkQueue {
queue_->AddTask(std::move(fn)); queue_->AddTask(std::move(fn));
} }
void WaitQueueEmpty() override {
if (tracker_ == nullptr) {
PADDLE_THROW(
platform::errors::Unavailable("set WorkQueueOptions.track_task = "
"true before call this interface."));
}
tracker_->WaitTaskNumToZero();
}
size_t NumThreads() const override { return queue_->NumThreads(); } size_t NumThreads() const override { return queue_->NumThreads(); }
private: private:
NonblockingThreadPool* queue_; NonblockingThreadPool* queue_{nullptr};
TaskTracker* tracker_; TaskTracker* tracker_{nullptr};
}; };
class WorkQueueGroupImpl : public WorkQueueGroup { class WorkQueueGroupImpl : public WorkQueueGroup {
...@@ -69,8 +65,6 @@ class WorkQueueGroupImpl : public WorkQueueGroup { ...@@ -69,8 +65,6 @@ class WorkQueueGroupImpl : public WorkQueueGroup {
void AddTask(size_t queue_idx, std::function<void()> fn) override; void AddTask(size_t queue_idx, std::function<void()> fn) override;
void WaitQueueGroupEmpty() override;
size_t QueueNumThreads(size_t queue_idx) const override; size_t QueueNumThreads(size_t queue_idx) const override;
size_t QueueGroupNumThreads() const override; size_t QueueGroupNumThreads() const override;
...@@ -92,9 +86,14 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( ...@@ -92,9 +86,14 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
queues_storage_ = reinterpret_cast<NonblockingThreadPool*>(buffer); queues_storage_ = reinterpret_cast<NonblockingThreadPool*>(buffer);
for (size_t idx = 0; idx < num_queues; ++idx) { for (size_t idx = 0; idx < num_queues; ++idx) {
const auto& options = queues_options_[idx]; const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr) { if (options.track_task && tracker_ == nullptr &&
options.queue_empty_waiter != nullptr) {
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
tracker_ = new (storage) TaskTracker; TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
auto notifier = options.queue_empty_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*notifier.get());
} }
queues_[idx] = new (&queues_storage_[idx]) queues_[idx] = new (&queues_storage_[idx])
NonblockingThreadPool(options.num_threads, options.allow_spinning); NonblockingThreadPool(options.num_threads, options.allow_spinning);
...@@ -124,15 +123,6 @@ void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) { ...@@ -124,15 +123,6 @@ void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) {
queues_[queue_idx]->AddTask(std::move(fn)); queues_[queue_idx]->AddTask(std::move(fn));
} }
void WorkQueueGroupImpl::WaitQueueGroupEmpty() {
if (nullptr == tracker_) {
PADDLE_THROW(platform::errors::Unavailable(
"set WorkQueueOptions.track_task = true for at least one of queues "
"before call this interface."));
}
tracker_->WaitTaskNumToZero();
}
size_t WorkQueueGroupImpl::QueueNumThreads(size_t queue_idx) const { size_t WorkQueueGroupImpl::QueueNumThreads(size_t queue_idx) const {
assert(queue_idx < queues_.size()); assert(queue_idx < queues_.size());
return queues_.at(queue_idx)->NumThreads(); return queues_.at(queue_idx)->NumThreads();
......
...@@ -21,15 +21,30 @@ ...@@ -21,15 +21,30 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
constexpr const char* kQueueEmptyEvent = "QueueEmpty";
class EventsWaiter;
struct WorkQueueOptions { struct WorkQueueOptions {
WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task) WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task)
: num_threads(num_threads), : num_threads(num_threads),
allow_spinning(allow_spinning), allow_spinning(allow_spinning),
track_task(track_task) {} track_task(track_task) {}
WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task,
EventsWaiter* waiter)
: num_threads(num_threads),
allow_spinning(allow_spinning),
track_task(track_task),
queue_empty_waiter(waiter) {}
size_t num_threads; size_t num_threads;
bool allow_spinning; bool allow_spinning;
// If you need to blocking the calling thread to wait "queue empty", set
// track_task = true and set queue_empty_waiter. EventsWaiter::WaitEvent will
// block the calling thread until any of events (including "queue empty")
// occured.
bool track_task; bool track_task;
EventsWaiter* queue_empty_waiter{nullptr}; // not owned
}; };
class WorkQueue { class WorkQueue {
...@@ -44,9 +59,8 @@ class WorkQueue { ...@@ -44,9 +59,8 @@ class WorkQueue {
virtual void AddTask(std::function<void()> fn) = 0; virtual void AddTask(std::function<void()> fn) = 0;
// set WorkQueueOptions.track_task = true before call this // See WorkQueueOptions.track_task for details
// interface, otherwise will abort() // virtual void WaitQueueEmpty() = 0;
virtual void WaitQueueEmpty() = 0;
virtual size_t NumThreads() const = 0; virtual size_t NumThreads() const = 0;
...@@ -67,9 +81,8 @@ class WorkQueueGroup { ...@@ -67,9 +81,8 @@ class WorkQueueGroup {
virtual void AddTask(size_t queue_idx, std::function<void()> fn) = 0; virtual void AddTask(size_t queue_idx, std::function<void()> fn) = 0;
// set WorkQueueOptions.track_task = true for at least one of queues // See WorkQueueOptions.track_task for details
// before call this interface, otherwise will abort() // virtual void WaitQueueGroupEmpty() = 0;
virtual void WaitQueueGroupEmpty() = 0;
virtual size_t QueueNumThreads(size_t queue_idx) const = 0; virtual size_t QueueNumThreads(size_t queue_idx) const = 0;
......
...@@ -16,18 +16,21 @@ ...@@ -16,18 +16,21 @@
#include <atomic> #include <atomic>
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
TEST(WorkQueue, TestSingleThreadedWorkQueue) { TEST(WorkQueue, TestSingleThreadedWorkQueue) {
VLOG(1) << "In Test"; VLOG(1) << "In Test";
using paddle::framework::WorkQueueOptions; using paddle::framework::WorkQueueOptions;
using paddle::framework::WorkQueue; using paddle::framework::WorkQueue;
using paddle::framework::CreateSingleThreadedWorkQueue; using paddle::framework::CreateSingleThreadedWorkQueue;
using paddle::framework::EventsWaiter;
std::atomic<bool> finished{false}; std::atomic<bool> finished{false};
std::atomic<unsigned> counter{0}; std::atomic<unsigned> counter{0};
constexpr unsigned kLoopNum = 1000000; constexpr unsigned kLoopNum = 1000000;
// CreateSingleThreadedWorkQueue // CreateSingleThreadedWorkQueue
EventsWaiter events_waiter;
WorkQueueOptions options(/*num_threads*/ 1, /*allow_spinning*/ true, WorkQueueOptions options(/*num_threads*/ 1, /*allow_spinning*/ true,
/*track_task*/ true); /*track_task*/ true, &events_waiter);
auto work_queue = CreateSingleThreadedWorkQueue(options); auto work_queue = CreateSingleThreadedWorkQueue(options);
// NumThreads // NumThreads
EXPECT_EQ(work_queue->NumThreads(), 1u); EXPECT_EQ(work_queue->NumThreads(), 1u);
...@@ -42,7 +45,7 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { ...@@ -42,7 +45,7 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
}); });
// WaitQueueEmpty // WaitQueueEmpty
EXPECT_EQ(finished.load(), false); EXPECT_EQ(finished.load(), false);
work_queue->WaitQueueEmpty(); events_waiter.WaitEvent();
EXPECT_EQ(finished.load(), true); EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum); EXPECT_EQ(counter.load(), kLoopNum);
} }
...@@ -52,13 +55,15 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { ...@@ -52,13 +55,15 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
using paddle::framework::WorkQueueOptions; using paddle::framework::WorkQueueOptions;
using paddle::framework::WorkQueue; using paddle::framework::WorkQueue;
using paddle::framework::CreateMultiThreadedWorkQueue; using paddle::framework::CreateMultiThreadedWorkQueue;
using paddle::framework::EventsWaiter;
std::atomic<bool> finished{false}; std::atomic<bool> finished{false};
std::atomic<unsigned> counter{0}; std::atomic<unsigned> counter{0};
constexpr unsigned kExternalLoopNum = 100; constexpr unsigned kExternalLoopNum = 100;
constexpr unsigned kLoopNum = 1000000; constexpr unsigned kLoopNum = 1000000;
// CreateMultiThreadedWorkQueue // CreateMultiThreadedWorkQueue
EventsWaiter events_waiter;
WorkQueueOptions options(/*num_threads*/ 10, /*allow_spinning*/ true, WorkQueueOptions options(/*num_threads*/ 10, /*allow_spinning*/ true,
/*track_task*/ true); /*track_task*/ true, &events_waiter);
auto work_queue = CreateMultiThreadedWorkQueue(options); auto work_queue = CreateMultiThreadedWorkQueue(options);
// NumThreads // NumThreads
EXPECT_EQ(work_queue->NumThreads(), 10u); EXPECT_EQ(work_queue->NumThreads(), 10u);
...@@ -75,7 +80,7 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { ...@@ -75,7 +80,7 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
} }
// WaitQueueEmpty // WaitQueueEmpty
EXPECT_EQ(finished.load(), false); EXPECT_EQ(finished.load(), false);
work_queue->WaitQueueEmpty(); events_waiter.WaitEvent();
EXPECT_EQ(finished.load(), true); EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum);
} }
...@@ -84,15 +89,17 @@ TEST(WorkQueue, TestWorkQueueGroup) { ...@@ -84,15 +89,17 @@ TEST(WorkQueue, TestWorkQueueGroup) {
using paddle::framework::WorkQueueOptions; using paddle::framework::WorkQueueOptions;
using paddle::framework::WorkQueueGroup; using paddle::framework::WorkQueueGroup;
using paddle::framework::CreateWorkQueueGroup; using paddle::framework::CreateWorkQueueGroup;
using paddle::framework::EventsWaiter;
std::atomic<bool> finished{false}; std::atomic<bool> finished{false};
std::atomic<unsigned> counter{0}; std::atomic<unsigned> counter{0};
constexpr unsigned kExternalLoopNum = 100; constexpr unsigned kExternalLoopNum = 100;
constexpr unsigned kLoopNum = 1000000; constexpr unsigned kLoopNum = 1000000;
// CreateMultiThreadedWorkQueue // ThreadedWorkQueueGroup
EventsWaiter events_waiter;
WorkQueueOptions sq_options(/*num_threads*/ 1, /*allow_spinning*/ true, WorkQueueOptions sq_options(/*num_threads*/ 1, /*allow_spinning*/ true,
/*track_task*/ true); /*track_task*/ true, &events_waiter);
WorkQueueOptions mq_options(/*num_threads*/ 10, /*allow_spinning*/ true, WorkQueueOptions mq_options(/*num_threads*/ 10, /*allow_spinning*/ true,
/*track_task*/ true); /*track_task*/ true, &events_waiter);
auto queue_group = CreateWorkQueueGroup({sq_options, mq_options}); auto queue_group = CreateWorkQueueGroup({sq_options, mq_options});
// NumThreads // NumThreads
EXPECT_EQ(queue_group->QueueNumThreads(0), 1u); EXPECT_EQ(queue_group->QueueNumThreads(0), 1u);
...@@ -113,6 +120,6 @@ TEST(WorkQueue, TestWorkQueueGroup) { ...@@ -113,6 +120,6 @@ TEST(WorkQueue, TestWorkQueueGroup) {
} }
}); });
// WaitQueueGroupEmpty() // WaitQueueGroupEmpty()
queue_group->WaitQueueGroupEmpty(); events_waiter.WaitEvent();
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
} }
...@@ -55,5 +55,62 @@ void AlignedFree(void* mem_ptr) { ...@@ -55,5 +55,62 @@ void AlignedFree(void* mem_ptr) {
#endif #endif
} }
// Sentinel EventId meaning "no event has been triggered".
constexpr EventsWaiter::EventId kEmptyEventId = -1;

// Starts with no triggered event and no waiter. cv_(1) sizes the EventCount
// for a single waiting thread — WaitEvent enforces the one-waiter limit.
EventsWaiter::EventsWaiter()
    : trigger_event_(kEmptyEventId), waiting_(false), cv_(1) {}
// Registers a named event together with a checker that reports whether the
// event currently holds, and returns the notifier the event source calls to
// wake the waiter. Not synchronized: per the class contract, all
// RegisterEvent calls must complete before the first WaitEvent.
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
    const std::string& name, EventChecker checker) {
  names_.emplace_back(name);
  checkers_.emplace_back(std::move(checker));
  // The EventId is the index into the parallel names_/checkers_/notifiers_
  // vectors; keeping them the same length is this function's invariant.
  EventId id = checkers_.size() - 1;
  // EventNotifier's constructor is private (friend access only), so
  // std::make_shared cannot be used here.
  auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
  notifiers_.emplace_back(notifier);
  return notifier;
}
// Blocks the calling thread until any registered event occurs (or already
// holds), then returns that event's name. At most one thread may be inside
// WaitEvent at a time; a concurrent caller gets ResourceExhausted.
std::string EventsWaiter::WaitEvent() {
  // only one user can wait at any time
  bool waiting = false;
  if (!waiting_.compare_exchange_strong(waiting, true,
                                        std::memory_order_seq_cst,
                                        std::memory_order_relaxed)) {
    PADDLE_THROW(
        platform::errors::ResourceExhausted("Another thread is waiting."));
  }
  EventId id = kEmptyEventId;
  auto w = cv_.GetWaiter(0);
  // Prewait/CommitWait protocol: announce intent to wait BEFORE re-checking
  // the conditions, so a NotifyEvent racing with the checks below is not lost.
  cv_.Prewait();
  int64_t event_num = checkers_.size();
  // Poll every checker once; stop at the first event that already holds.
  for (int64_t i = 0; id == kEmptyEventId && i < event_num; ++i) {
    if (checkers_[i]()) {
      id = i;
    }
  }
  if (id != kEmptyEventId) {
    // An event already holds — abandon the announced wait and return it.
    cv_.CancelWait();
  } else {
    // Sleep until SetTriggerEvent notifies, then read which event fired.
    // NOTE(review): if the wakeup happened without SetTriggerEvent having
    // stored an id, this stays kEmptyEventId (-1) and names_.at(id) below
    // throws std::out_of_range — confirm EventCount cannot wake spuriously.
    cv_.CommitWait(w);
    id = trigger_event_.load(std::memory_order_relaxed);
  }
  // Reset state so a subsequent WaitEvent call starts clean.
  trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
  waiting_.store(false);
  return names_.at(id);
}
// Records which event fired and wakes the waiting thread (if any).
// Called from the event source's thread via EventNotifier::NotifyEvent;
// the store must precede Notify so WaitEvent reads a valid id after waking.
void EventsWaiter::SetTriggerEvent(const EventId& id) {
  trigger_event_.store(id, std::memory_order_relaxed);
  cv_.Notify(true);
}
// Returns the human-readable name this notifier's event was registered with.
std::string EventsWaiter::EventNotifier::GetEventName() {
  return waiter_.names_.at(id_);
}

// Signals the owning EventsWaiter that this notifier's event occurred,
// waking a thread blocked in WaitEvent.
void EventsWaiter::EventNotifier::NotifyEvent() {
  waiter_.SetTriggerEvent(id_);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -18,6 +18,11 @@ ...@@ -18,6 +18,11 @@
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <cstdlib> #include <cstdlib>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/new_executor/event_count.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -64,5 +69,56 @@ void* AlignedMalloc(size_t size, size_t alignment); ...@@ -64,5 +69,56 @@ void* AlignedMalloc(size_t size, size_t alignment);
void AlignedFree(void* memory_ptr); void AlignedFree(void* memory_ptr);
// A multiplexing waiter, be able to wait multi events simultaneously.
// Blocking the calling thread to wait any of the registered events.
// Non-thread-safe.
class EventsWaiter {
 public:
  // Index of an event in the registration order.
  using EventId = int64_t;

  // Returns true when the associated event currently holds; polled by
  // WaitEvent before committing to sleep.
  using EventChecker = std::function<bool()>;

  // Handle given to an event source so it can wake the waiter.
  class EventNotifier {
   public:
    // Signals the owning EventsWaiter that this event occurred.
    void NotifyEvent();

    EventId GetEventId() { return id_; }

    // Name this event was registered with.
    std::string GetEventName();

   private:
    friend EventsWaiter;

    // Private: only EventsWaiter::RegisterEvent creates notifiers.
    EventNotifier(EventId id, EventsWaiter* waiter)
        : id_(id), waiter_(*waiter) {}

    EventId id_;
    EventsWaiter& waiter_;  // not owned; the waiter must outlive the notifier
  };

  EventsWaiter();

  // Non-copyable: notifiers hold a reference to this instance.
  EventsWaiter(const EventsWaiter&) = delete;
  EventsWaiter& operator=(const EventsWaiter&) = delete;

  // All the RegisterEvent functions must be called before any WaitEvent
  std::shared_ptr<EventNotifier> RegisterEvent(const std::string& name,
                                               EventChecker checker);

  // Wait any of the registered events
  std::string WaitEvent();

 private:
  friend EventNotifier;

  void SetTriggerEvent(const EventId& id);

  // Parallel vectors indexed by EventId, filled by RegisterEvent.
  std::vector<std::string> names_;
  std::vector<EventChecker> checkers_;
  std::vector<std::shared_ptr<EventNotifier>> notifiers_;
  // Id of the most recently notified event; kEmptyEventId when none.
  std::atomic<EventId> trigger_event_;
  // Guards against more than one concurrent WaitEvent caller.
  std::atomic<bool> waiting_;
  EventCount cv_;
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册