未验证 提交 36ee6dd3 编写于 作者: L liutiexing 提交者: GitHub

Refine events waiter (#40876)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Add EventsWaiter

* update

* Revert "Add EventsWaiter"

This reverts commit e206173aa9be7401b83a53581627bfaf557c8fb2.

* update

* update Error MSG

* update EventsWaiter

* update
Co-authored-by: Nliutiexing <liutiexing@google.com>
上级 2e8f9882
...@@ -90,9 +90,6 @@ InterpreterCore::~InterpreterCore() { ...@@ -90,9 +90,6 @@ InterpreterCore::~InterpreterCore() {
// cancel gc's thread // cancel gc's thread
gc_.reset(nullptr); gc_.reset(nullptr);
exception_notifier_->UnregisterEvent();
completion_notifier_->UnregisterEvent();
async_work_queue_.reset(nullptr); async_work_queue_.reset(nullptr);
} }
......
...@@ -19,37 +19,79 @@ ...@@ -19,37 +19,79 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
constexpr EventsWaiter::EventId kEmptyEventId = 0;
EventsWaiter::EventsWaiter() EventsWaiter::EventsWaiter()
: trigger_event_(nullptr), counter_(0), waiting_(false), cv_(1) {} : trigger_event_(kEmptyEventId),
counter_(0),
eof_(true),
waiting_(false),
cv_(1) {}
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent( std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name, EventChecker checker) { const std::string& name, EventChecker checker) {
auto counter = counter_.fetch_add(1); EventId id = kEmptyEventId;
auto id = std::hash<std::string>()(name + std::to_string(counter)); EventInfo* evt = nullptr;
do {
auto counter = counter_.fetch_add(1);
id = std::hash<std::string>()(name + std::to_string(counter));
if (id == kEmptyEventId) {
continue;
}
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (events_.count(id) > 0) {
continue;
}
evt = &(events_[id]);
} while (evt == nullptr);
evt->id = id;
evt->name = name;
evt->type = TriggerType::LevelTriggered;
evt->checker = std::move(checker);
eof_.store(false, std::memory_order_relaxed);
VLOG(10) << "Register event id:" << id << " name:" << name; VLOG(10) << "Register event id:" << id << " name:" << name;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this)); auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
EventInfo evt{id, name, TriggerType::LevelTriggered, std::move(checker)};
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_[id] = std::move(evt);
return notifier; return notifier;
} }
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent( std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name) { const std::string& name) {
auto counter = counter_.fetch_add(1); EventId id = kEmptyEventId;
auto id = std::hash<std::string>()(name + std::to_string(counter)); EventInfo* evt = nullptr;
do {
auto counter = counter_.fetch_add(1);
id = std::hash<std::string>()(name + std::to_string(counter));
if (id == kEmptyEventId) {
continue;
}
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (events_.count(id) > 0) {
continue;
}
evt = &(events_[id]);
} while (evt == nullptr);
evt->id = id;
evt->name = name;
evt->type = TriggerType::EdgeTriggered;
evt->checker = []() { return false; };
eof_.store(false, std::memory_order_relaxed);
VLOG(10) << "Register event id:" << id << " name:" << name; VLOG(10) << "Register event id:" << id << " name:" << name;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this)); auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
EventInfo evt{id, name, TriggerType::EdgeTriggered, []() { return false; }};
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_[id] = std::move(evt);
return notifier; return notifier;
} }
void EventsWaiter::UnregisterEvent(const EventId& id) { void EventsWaiter::UnregisterEvent(const EventId& id) {
VLOG(10) << "Unregister event id:" << id; VLOG(10) << "Unregister event id:" << id;
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_); {
events_.erase(id); std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
deleted_events_.insert(id);
if (deleted_events_.size() == events_.size()) {
eof_.store(true, std::memory_order_relaxed);
}
}
if (eof_.load(std::memory_order_relaxed)) {
cv_.Notify(true);
}
} }
std::string EventsWaiter::WaitEvent() { std::string EventsWaiter::WaitEvent() {
...@@ -61,42 +103,60 @@ std::string EventsWaiter::WaitEvent() { ...@@ -61,42 +103,60 @@ std::string EventsWaiter::WaitEvent() {
PADDLE_THROW( PADDLE_THROW(
platform::errors::ResourceExhausted("Another thread is waiting.")); platform::errors::ResourceExhausted("Another thread is waiting."));
} }
auto w = cv_.GetWaiter(0); auto w = cv_.GetWaiter(0);
cv_.Prewait(); EventId triggered = trigger_event_;
std::string* triggered = trigger_event_; while (triggered == kEmptyEventId && !eof_) {
if (triggered == nullptr) { cv_.Prewait();
// double check
triggered = trigger_event_;
// checkers // checkers
{ if (triggered == kEmptyEventId) {
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_); {
for (auto& kv : events_) { std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
auto& evt = kv.second; for (auto& kv : events_) {
if (TriggerType::LevelTriggered == evt.type && evt.checker()) { auto& evt = kv.second;
triggered = new std::string(evt.name); if (TriggerType::LevelTriggered == evt.type && evt.checker()) {
break; triggered = evt.id;
break;
}
} }
} }
} if (triggered != kEmptyEventId) {
if (triggered != nullptr) { EventId prev = kEmptyEventId;
std::string* prev = nullptr; if (!trigger_event_.compare_exchange_strong(
if (!trigger_event_.compare_exchange_strong(prev, triggered, prev, triggered, std::memory_order_seq_cst,
std::memory_order_seq_cst, std::memory_order_relaxed)) {
std::memory_order_relaxed)) { triggered = prev;
delete triggered; }
triggered = prev;
} }
} }
if (triggered != kEmptyEventId || eof_) {
cv_.CancelWait();
} else {
cv_.CommitWait(w);
triggered = trigger_event_;
}
} }
if (triggered) {
cv_.CancelWait(); trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
} else { waiting_.store(false, std::memory_order_relaxed);
cv_.CommitWait(w); std::string evt_name =
triggered = trigger_event_; triggered == kEmptyEventId ? "NoEventNotifier" : GetEventName(triggered);
VLOG(10) << "Consume event id:" << triggered << ", name:" << evt_name;
// lazy deletion
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (deleted_events_.size() > 0) {
for (auto evt : deleted_events_) {
events_.erase(evt);
}
deleted_events_.clear();
}
} }
trigger_event_.store(nullptr, std::memory_order_relaxed); return evt_name;
waiting_.store(false);
auto trigger_event = *triggered;
delete triggered;
return trigger_event;
} }
int EventsWaiter::Clear() { int EventsWaiter::Clear() {
...@@ -106,32 +166,33 @@ int EventsWaiter::Clear() { ...@@ -106,32 +166,33 @@ int EventsWaiter::Clear() {
std::memory_order_relaxed)) { std::memory_order_relaxed)) {
return -1; return -1;
} }
trigger_event_.store(nullptr, std::memory_order_relaxed); trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false); waiting_.store(false);
return 0; return 0;
} }
void EventsWaiter::TriggerEvent(const EventId& id) { void EventsWaiter::TriggerEvent(const EventId& id) {
VLOG(10) << "Try to trigger event id:" << id; VLOG(10) << "Try to trigger event id:" << id;
std::string* trigger_event = new std::string; EventId prev = kEmptyEventId;
{ if (!trigger_event_.compare_exchange_strong(
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_); prev, id, std::memory_order_seq_cst, std::memory_order_relaxed)) {
auto iter = events_.find(id); VLOG(10) << "Event id:" << prev << " is pending";
if (iter == events_.end()) { return;
delete trigger_event;
return;
}
*trigger_event = iter->second.name;
} }
std::string* prev = nullptr; VLOG(10) << "Triggered event id:" << id;
if (!trigger_event_.compare_exchange_strong(prev, trigger_event, cv_.Notify(true);
}
void EventsWaiter::CancelEvent(const EventId& id) {
VLOG(10) << "Try to cancel event id:" << id;
EventId prev = id;
if (!trigger_event_.compare_exchange_strong(prev, kEmptyEventId,
std::memory_order_seq_cst, std::memory_order_seq_cst,
std::memory_order_relaxed)) { std::memory_order_relaxed)) {
delete trigger_event; VLOG(10) << "Event id:" << prev << " is pending";
return; return;
} }
VLOG(10) << "Triggered event id:" << id << " name:" << *trigger_event; VLOG(10) << "Cancelled event id:" << id;
cv_.Notify(true);
} }
std::string EventsWaiter::GetEventName(const EventId& id) { std::string EventsWaiter::GetEventName(const EventId& id) {
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <functional> #include <functional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/new_executor/workqueue/event_count.h" #include "paddle/fluid/framework/new_executor/workqueue/event_count.h"
#include "paddle/fluid/memory/allocation/spin_lock.h" #include "paddle/fluid/memory/allocation/spin_lock.h"
...@@ -37,13 +38,12 @@ class EventsWaiter { ...@@ -37,13 +38,12 @@ class EventsWaiter {
// Make sure EventsWaiter has a longer lifetime than EventNotifier. // Make sure EventsWaiter has a longer lifetime than EventNotifier.
class EventNotifier { class EventNotifier {
public: public:
void NotifyEvent() { waiter_.TriggerEvent(id_); } ~EventNotifier() { waiter_.UnregisterEvent(id_); }
void UnregisterEvent() { waiter_.UnregisterEvent(id_); } void NotifyEvent() { waiter_.TriggerEvent(id_); }
EventId GetEventId() { return id_; } void CancelEvent() { waiter_.CancelEvent(id_); }
// return "Unregistered" if the corresponding event was unregistered.
std::string GetEventName() { return waiter_.GetEventName(id_); } std::string GetEventName() { return waiter_.GetEventName(id_); }
private: private:
...@@ -97,12 +97,16 @@ class EventsWaiter { ...@@ -97,12 +97,16 @@ class EventsWaiter {
void TriggerEvent(const EventId& id); void TriggerEvent(const EventId& id);
void CancelEvent(const EventId& id);
std::string GetEventName(const EventId& id); std::string GetEventName(const EventId& id);
std::unordered_map<EventId, EventInfo> events_; std::unordered_map<EventId, EventInfo> events_;
std::unordered_set<EventId> deleted_events_;
paddle::memory::SpinLock events_lock_; paddle::memory::SpinLock events_lock_;
std::atomic<std::string*> trigger_event_; std::atomic<EventId> trigger_event_;
std::atomic<uint64_t> counter_; std::atomic<uint64_t> counter_;
std::atomic<bool> eof_;
std::atomic<bool> waiting_; std::atomic<bool> waiting_;
EventCount cv_; EventCount cv_;
}; };
......
...@@ -31,11 +31,8 @@ class WorkQueueImpl : public WorkQueue { ...@@ -31,11 +31,8 @@ class WorkQueueImpl : public WorkQueue {
public: public:
explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) { explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) {
if (options_.track_task && options.events_waiter != nullptr) { if (options_.track_task && options.events_waiter != nullptr) {
empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent);
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*empty_notifier_.get()); tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
} }
if (options_.detached == false && options.events_waiter != nullptr) { if (options_.detached == false && options.events_waiter != nullptr) {
...@@ -47,9 +44,6 @@ class WorkQueueImpl : public WorkQueue { ...@@ -47,9 +44,6 @@ class WorkQueueImpl : public WorkQueue {
} }
virtual ~WorkQueueImpl() { virtual ~WorkQueueImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
delete queue_; delete queue_;
if (tracker_ != nullptr) { if (tracker_ != nullptr) {
tracker_->~TaskTracker(); tracker_->~TaskTracker();
...@@ -57,7 +51,6 @@ class WorkQueueImpl : public WorkQueue { ...@@ -57,7 +51,6 @@ class WorkQueueImpl : public WorkQueue {
} }
if (destruct_notifier_) { if (destruct_notifier_) {
destruct_notifier_->NotifyEvent(); destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
} }
} }
...@@ -124,14 +117,12 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( ...@@ -124,14 +117,12 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
const auto& options = queues_options_[idx]; const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr && if (options.track_task && tracker_ == nullptr &&
options.events_waiter != nullptr) { options.events_waiter != nullptr) {
empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent);
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*empty_notifier_.get()); tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
} }
if (options.detached == false && options.events_waiter != nullptr) { if (options.detached == false && options.events_waiter != nullptr &&
!destruct_notifier_) {
destruct_notifier_ = destruct_notifier_ =
options.events_waiter->RegisterEvent(kQueueDestructEvent); options.events_waiter->RegisterEvent(kQueueDestructEvent);
} }
...@@ -141,9 +132,6 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( ...@@ -141,9 +132,6 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
} }
WorkQueueGroupImpl::~WorkQueueGroupImpl() { WorkQueueGroupImpl::~WorkQueueGroupImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
for (auto queue : queues_) { for (auto queue : queues_) {
queue->~NonblockingThreadPool(); queue->~NonblockingThreadPool();
} }
...@@ -154,7 +142,6 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() { ...@@ -154,7 +142,6 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() {
free(queues_storage_); free(queues_storage_);
if (destruct_notifier_) { if (destruct_notifier_) {
destruct_notifier_->NotifyEvent(); destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
} }
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/new_executor/workqueue/workqueue.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue.h"
#include <atomic> #include <atomic>
#include <thread>
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
...@@ -26,11 +27,12 @@ TEST(WorkQueueUtils, TestEventsWaiter) { ...@@ -26,11 +27,12 @@ TEST(WorkQueueUtils, TestEventsWaiter) {
EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt"); EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt");
EXPECT_EQ(notifier->GetEventName(), "test_register_lt"); EXPECT_EQ(notifier->GetEventName(), "test_register_lt");
EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt"); EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt");
notifier->UnregisterEvent(); notifier.reset();
EXPECT_EQ(notifier->GetEventName(), "Unregistered");
notifier = events_waiter.RegisterEvent("test_register_et"); notifier = events_waiter.RegisterEvent("test_register_et");
notifier->NotifyEvent(); notifier->NotifyEvent();
EXPECT_EQ(events_waiter.WaitEvent(), "test_register_et"); EXPECT_EQ(events_waiter.WaitEvent(), "test_register_et");
notifier->NotifyEvent();
notifier->CancelEvent();
} }
TEST(WorkQueue, TestSingleThreadedWorkQueue) { TEST(WorkQueue, TestSingleThreadedWorkQueue) {
...@@ -106,8 +108,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { ...@@ -106,8 +108,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum);
// Cancel // Cancel
work_queue->Cancel(); work_queue->Cancel();
// Wait kQueueDestructEvent
std::thread waiter_thread([&events_waiter]() {
EXPECT_EQ(events_waiter.WaitEvent(),
paddle::framework::kQueueDestructEvent);
});
work_queue.reset(); work_queue.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); waiter_thread.join();
} }
TEST(WorkQueue, TestWorkQueueGroup) { TEST(WorkQueue, TestWorkQueueGroup) {
...@@ -154,10 +161,15 @@ TEST(WorkQueue, TestWorkQueueGroup) { ...@@ -154,10 +161,15 @@ TEST(WorkQueue, TestWorkQueueGroup) {
// WaitQueueGroupEmpty // WaitQueueGroupEmpty
events_waiter.WaitEvent(); events_waiter.WaitEvent();
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
EXPECT_EQ(handle.get(), random_num);
// Cancel // Cancel
queue_group->Cancel(); queue_group->Cancel();
events_waiter.WaitEvent(); // Wait kQueueDestructEvent
std::thread waiter_thread([&events_waiter]() {
EXPECT_EQ(events_waiter.WaitEvent(),
paddle::framework::kQueueDestructEvent);
EXPECT_EQ(events_waiter.WaitEvent(), "NoEventNotifier");
});
queue_group.reset(); queue_group.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); waiter_thread.join();
EXPECT_EQ(handle.get(), random_num);
} }
...@@ -81,7 +81,13 @@ class TaskTracker { ...@@ -81,7 +81,13 @@ class TaskTracker {
~TaskTracker() = default; ~TaskTracker() = default;
void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); } void AddCounter() {
if (0 == num_tasks_.fetch_add(1, std::memory_order_relaxed)) {
if (notifier_ != nullptr) {
notifier_->CancelEvent();
}
}
}
void SubCounter() { void SubCounter() {
if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) { if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册