未验证 提交 36ee6dd3 编写于 作者: L liutiexing 提交者: GitHub

Refine events waiter (#40876)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Add EventsWaiter

* update

* Revert "Add EventsWaiter"

This reverts commit e206173aa9be7401b83a53581627bfaf557c8fb2.

* update

* update Error MSG

* update EventsWaiter

* update
Co-authored-by: Nliutiexing <liutiexing@google.com>
上级 2e8f9882
......@@ -90,9 +90,6 @@ InterpreterCore::~InterpreterCore() {
// cancle gc's thread
gc_.reset(nullptr);
exception_notifier_->UnregisterEvent();
completion_notifier_->UnregisterEvent();
async_work_queue_.reset(nullptr);
}
......
......@@ -19,37 +19,79 @@
namespace paddle {
namespace framework {
constexpr EventsWaiter::EventId kEmptyEventId = 0;
EventsWaiter::EventsWaiter()
: trigger_event_(nullptr), counter_(0), waiting_(false), cv_(1) {}
: trigger_event_(kEmptyEventId),
counter_(0),
eof_(true),
waiting_(false),
cv_(1) {}
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name, EventChecker checker) {
EventId id = kEmptyEventId;
EventInfo* evt = nullptr;
do {
auto counter = counter_.fetch_add(1);
auto id = std::hash<std::string>()(name + std::to_string(counter));
id = std::hash<std::string>()(name + std::to_string(counter));
if (id == kEmptyEventId) {
continue;
}
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (events_.count(id) > 0) {
continue;
}
evt = &(events_[id]);
} while (evt == nullptr);
evt->id = id;
evt->name = name;
evt->type = TriggerType::LevelTriggered;
evt->checker = std::move(checker);
eof_.store(false, std::memory_order_relaxed);
VLOG(10) << "Register event id:" << id << " name:" << name;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
EventInfo evt{id, name, TriggerType::LevelTriggered, std::move(checker)};
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_[id] = std::move(evt);
return notifier;
}
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name) {
EventId id = kEmptyEventId;
EventInfo* evt = nullptr;
do {
auto counter = counter_.fetch_add(1);
auto id = std::hash<std::string>()(name + std::to_string(counter));
id = std::hash<std::string>()(name + std::to_string(counter));
if (id == kEmptyEventId) {
continue;
}
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (events_.count(id) > 0) {
continue;
}
evt = &(events_[id]);
} while (evt == nullptr);
evt->id = id;
evt->name = name;
evt->type = TriggerType::EdgeTriggered;
evt->checker = []() { return false; };
eof_.store(false, std::memory_order_relaxed);
VLOG(10) << "Register event id:" << id << " name:" << name;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
EventInfo evt{id, name, TriggerType::EdgeTriggered, []() { return false; }};
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_[id] = std::move(evt);
return notifier;
}
void EventsWaiter::UnregisterEvent(const EventId& id) {
VLOG(10) << "Unregister event id:" << id;
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_.erase(id);
deleted_events_.insert(id);
if (deleted_events_.size() == events_.size()) {
eof_.store(true, std::memory_order_relaxed);
}
}
if (eof_.load(std::memory_order_relaxed)) {
cv_.Notify(true);
}
}
std::string EventsWaiter::WaitEvent() {
......@@ -61,42 +103,60 @@ std::string EventsWaiter::WaitEvent() {
PADDLE_THROW(
platform::errors::ResourceExhausted("Another thread is waiting."));
}
auto w = cv_.GetWaiter(0);
EventId triggered = trigger_event_;
while (triggered == kEmptyEventId && !eof_) {
cv_.Prewait();
std::string* triggered = trigger_event_;
if (triggered == nullptr) {
// double check
triggered = trigger_event_;
// checkers
if (triggered == kEmptyEventId) {
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
for (auto& kv : events_) {
auto& evt = kv.second;
if (TriggerType::LevelTriggered == evt.type && evt.checker()) {
triggered = new std::string(evt.name);
triggered = evt.id;
break;
}
}
}
if (triggered != nullptr) {
std::string* prev = nullptr;
if (!trigger_event_.compare_exchange_strong(prev, triggered,
std::memory_order_seq_cst,
if (triggered != kEmptyEventId) {
EventId prev = kEmptyEventId;
if (!trigger_event_.compare_exchange_strong(
prev, triggered, std::memory_order_seq_cst,
std::memory_order_relaxed)) {
delete triggered;
triggered = prev;
}
}
}
if (triggered) {
if (triggered != kEmptyEventId || eof_) {
cv_.CancelWait();
} else {
cv_.CommitWait(w);
triggered = trigger_event_;
}
trigger_event_.store(nullptr, std::memory_order_relaxed);
waiting_.store(false);
auto trigger_event = *triggered;
delete triggered;
return trigger_event;
}
trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false, std::memory_order_relaxed);
std::string evt_name =
triggered == kEmptyEventId ? "NoEventNotifier" : GetEventName(triggered);
VLOG(10) << "Consume event id:" << triggered << ", name:" << evt_name;
// lazy deletion
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (deleted_events_.size() > 0) {
for (auto evt : deleted_events_) {
events_.erase(evt);
}
deleted_events_.clear();
}
}
return evt_name;
}
int EventsWaiter::Clear() {
......@@ -106,32 +166,33 @@ int EventsWaiter::Clear() {
std::memory_order_relaxed)) {
return -1;
}
trigger_event_.store(nullptr, std::memory_order_relaxed);
trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false);
return 0;
}
void EventsWaiter::TriggerEvent(const EventId& id) {
VLOG(10) << "Try to trigger event id:" << id;
std::string* trigger_event = new std::string;
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
auto iter = events_.find(id);
if (iter == events_.end()) {
delete trigger_event;
EventId prev = kEmptyEventId;
if (!trigger_event_.compare_exchange_strong(
prev, id, std::memory_order_seq_cst, std::memory_order_relaxed)) {
VLOG(10) << "Event id:" << prev << " is pending";
return;
}
*trigger_event = iter->second.name;
}
std::string* prev = nullptr;
if (!trigger_event_.compare_exchange_strong(prev, trigger_event,
VLOG(10) << "Triggered event id:" << id;
cv_.Notify(true);
}
void EventsWaiter::CancelEvent(const EventId& id) {
VLOG(10) << "Try to cancel event id:" << id;
EventId prev = id;
if (!trigger_event_.compare_exchange_strong(prev, kEmptyEventId,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
delete trigger_event;
VLOG(10) << "Event id:" << prev << " is pending";
return;
}
VLOG(10) << "Triggered event id:" << id << " name:" << *trigger_event;
cv_.Notify(true);
VLOG(10) << "Cancelled event id:" << id;
}
std::string EventsWaiter::GetEventName(const EventId& id) {
......
......@@ -19,6 +19,7 @@
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/new_executor/workqueue/event_count.h"
#include "paddle/fluid/memory/allocation/spin_lock.h"
......@@ -37,13 +38,12 @@ class EventsWaiter {
// Make sure EventsWaiter has a longer lifetime than EventNotifier.
class EventNotifier {
public:
void NotifyEvent() { waiter_.TriggerEvent(id_); }
~EventNotifier() { waiter_.UnregisterEvent(id_); }
void UnregisterEvent() { waiter_.UnregisterEvent(id_); }
void NotifyEvent() { waiter_.TriggerEvent(id_); }
EventId GetEventId() { return id_; }
void CancelEvent() { waiter_.CancelEvent(id_); }
// return "Unregistered" if the corresponding event was unregistered.
std::string GetEventName() { return waiter_.GetEventName(id_); }
private:
......@@ -97,12 +97,16 @@ class EventsWaiter {
void TriggerEvent(const EventId& id);
void CancelEvent(const EventId& id);
std::string GetEventName(const EventId& id);
std::unordered_map<EventId, EventInfo> events_;
std::unordered_set<EventId> deleted_events_;
paddle::memory::SpinLock events_lock_;
std::atomic<std::string*> trigger_event_;
std::atomic<EventId> trigger_event_;
std::atomic<uint64_t> counter_;
std::atomic<bool> eof_;
std::atomic<bool> waiting_;
EventCount cv_;
};
......
......@@ -31,11 +31,8 @@ class WorkQueueImpl : public WorkQueue {
public:
explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) {
if (options_.track_task && options.events_waiter != nullptr) {
empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent);
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
}
if (options_.detached == false && options.events_waiter != nullptr) {
......@@ -47,9 +44,6 @@ class WorkQueueImpl : public WorkQueue {
}
virtual ~WorkQueueImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
delete queue_;
if (tracker_ != nullptr) {
tracker_->~TaskTracker();
......@@ -57,7 +51,6 @@ class WorkQueueImpl : public WorkQueue {
}
if (destruct_notifier_) {
destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
}
}
......@@ -124,14 +117,12 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr &&
options.events_waiter != nullptr) {
empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent);
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
}
if (options.detached == false && options.events_waiter != nullptr) {
if (options.detached == false && options.events_waiter != nullptr &&
!destruct_notifier_) {
destruct_notifier_ =
options.events_waiter->RegisterEvent(kQueueDestructEvent);
}
......@@ -141,9 +132,6 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
}
WorkQueueGroupImpl::~WorkQueueGroupImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
for (auto queue : queues_) {
queue->~NonblockingThreadPool();
}
......@@ -154,7 +142,6 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() {
free(queues_storage_);
if (destruct_notifier_) {
destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
}
}
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/new_executor/workqueue/workqueue.h"
#include <atomic>
#include <thread>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
......@@ -26,11 +27,12 @@ TEST(WorkQueueUtils, TestEventsWaiter) {
EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt");
EXPECT_EQ(notifier->GetEventName(), "test_register_lt");
EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt");
notifier->UnregisterEvent();
EXPECT_EQ(notifier->GetEventName(), "Unregistered");
notifier.reset();
notifier = events_waiter.RegisterEvent("test_register_et");
notifier->NotifyEvent();
EXPECT_EQ(events_waiter.WaitEvent(), "test_register_et");
notifier->NotifyEvent();
notifier->CancelEvent();
}
TEST(WorkQueue, TestSingleThreadedWorkQueue) {
......@@ -106,8 +108,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum);
// Cancel
work_queue->Cancel();
// Wait kQueueDestructEvent
std::thread waiter_thread([&events_waiter]() {
EXPECT_EQ(events_waiter.WaitEvent(),
paddle::framework::kQueueDestructEvent);
});
work_queue.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent);
waiter_thread.join();
}
TEST(WorkQueue, TestWorkQueueGroup) {
......@@ -154,10 +161,15 @@ TEST(WorkQueue, TestWorkQueueGroup) {
// WaitQueueGroupEmpty
events_waiter.WaitEvent();
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
EXPECT_EQ(handle.get(), random_num);
// Cancel
queue_group->Cancel();
events_waiter.WaitEvent();
// Wait kQueueDestructEvent
std::thread waiter_thread([&events_waiter]() {
EXPECT_EQ(events_waiter.WaitEvent(),
paddle::framework::kQueueDestructEvent);
EXPECT_EQ(events_waiter.WaitEvent(), "NoEventNotifier");
});
queue_group.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent);
EXPECT_EQ(handle.get(), random_num);
waiter_thread.join();
}
......@@ -81,7 +81,13 @@ class TaskTracker {
~TaskTracker() = default;
void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); }
void AddCounter() {
if (0 == num_tasks_.fetch_add(1, std::memory_order_relaxed)) {
if (notifier_ != nullptr) {
notifier_->CancelEvent();
}
}
}
void SubCounter() {
if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册