diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 25cb15d2cc8c27e5fa1477e60e4428d5823495dd..6e73aaef15e07d3f75bb463b9fcaa8a8fde5c834 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -90,9 +90,6 @@ InterpreterCore::~InterpreterCore() { // cancle gc's thread gc_.reset(nullptr); - exception_notifier_->UnregisterEvent(); - completion_notifier_->UnregisterEvent(); - async_work_queue_.reset(nullptr); } diff --git a/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc b/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc index ac45e7b5fdfe9feb284a0a5e156e6aacbc43f48b..163050ae5a65a7758fdcbe00f5e880acf4262f9d 100644 --- a/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc +++ b/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc @@ -19,37 +19,79 @@ namespace paddle { namespace framework { +constexpr EventsWaiter::EventId kEmptyEventId = 0; + EventsWaiter::EventsWaiter() - : trigger_event_(nullptr), counter_(0), waiting_(false), cv_(1) {} + : trigger_event_(kEmptyEventId), + counter_(0), + eof_(true), + waiting_(false), + cv_(1) {} std::shared_ptr EventsWaiter::RegisterEvent( const std::string& name, EventChecker checker) { - auto counter = counter_.fetch_add(1); - auto id = std::hash()(name + std::to_string(counter)); + EventId id = kEmptyEventId; + EventInfo* evt = nullptr; + do { + auto counter = counter_.fetch_add(1); + id = std::hash()(name + std::to_string(counter)); + if (id == kEmptyEventId) { + continue; + } + std::lock_guard guard(events_lock_); + if (events_.count(id) > 0) { + continue; + } + evt = &(events_[id]); + } while (evt == nullptr); + evt->id = id; + evt->name = name; + evt->type = TriggerType::LevelTriggered; + evt->checker = std::move(checker); + eof_.store(false, std::memory_order_relaxed); VLOG(10) << "Register event id:" << id << " name:" << name; auto notifier = std::shared_ptr(new EventNotifier(id, this)); - EventInfo evt{id, name, TriggerType::LevelTriggered, std::move(checker)}; - std::lock_guard guard(events_lock_); - events_[id] = std::move(evt); return notifier; } std::shared_ptr EventsWaiter::RegisterEvent( const std::string& name) { - auto counter = counter_.fetch_add(1); - auto id = std::hash()(name + std::to_string(counter)); + EventId id = kEmptyEventId; + EventInfo* evt = nullptr; + do { + auto counter = counter_.fetch_add(1); + id = std::hash()(name + std::to_string(counter)); + if (id == kEmptyEventId) { + continue; + } + std::lock_guard guard(events_lock_); + if (events_.count(id) > 0) { + continue; + } + evt = &(events_[id]); + } while (evt == nullptr); + evt->id = id; + evt->name = name; + evt->type = TriggerType::EdgeTriggered; + evt->checker = []() { return false; }; + eof_.store(false, std::memory_order_relaxed); VLOG(10) << "Register event id:" << id << " name:" << name; auto notifier = std::shared_ptr(new EventNotifier(id, this)); - EventInfo evt{id, name, TriggerType::EdgeTriggered, []() { return false; }}; - std::lock_guard guard(events_lock_); - events_[id] = std::move(evt); return notifier; } void EventsWaiter::UnregisterEvent(const EventId& id) { VLOG(10) << "Unregister event id:" << id; - std::lock_guard guard(events_lock_); - events_.erase(id); + { + std::lock_guard guard(events_lock_); + deleted_events_.insert(id); + if (deleted_events_.size() == events_.size()) { + eof_.store(true, std::memory_order_relaxed); + } + } + if (eof_.load(std::memory_order_relaxed)) { + cv_.Notify(true); + } } std::string EventsWaiter::WaitEvent() { @@ -61,42 +103,60 @@ std::string EventsWaiter::WaitEvent() { PADDLE_THROW( platform::errors::ResourceExhausted("Another thread is waiting.")); } + auto w = cv_.GetWaiter(0); - cv_.Prewait(); - std::string* triggered = trigger_event_; - if (triggered == nullptr) { + EventId triggered = trigger_event_; + while (triggered == kEmptyEventId && !eof_) { + cv_.Prewait(); + + // double check + triggered = trigger_event_; // checkers - { - std::lock_guard guard(events_lock_); - for (auto& kv : events_) { - auto& evt = kv.second; - if (TriggerType::LevelTriggered == evt.type && evt.checker()) { - triggered = new std::string(evt.name); - break; + if (triggered == kEmptyEventId) { + { + std::lock_guard guard(events_lock_); + for (auto& kv : events_) { + auto& evt = kv.second; + if (TriggerType::LevelTriggered == evt.type && evt.checker()) { + triggered = evt.id; + break; + } } } - } - if (triggered != nullptr) { - std::string* prev = nullptr; - if (!trigger_event_.compare_exchange_strong(prev, triggered, - std::memory_order_seq_cst, - std::memory_order_relaxed)) { - delete triggered; - triggered = prev; + if (triggered != kEmptyEventId) { + EventId prev = kEmptyEventId; + if (!trigger_event_.compare_exchange_strong( + prev, triggered, std::memory_order_seq_cst, + std::memory_order_relaxed)) { + triggered = prev; + } } } + + if (triggered != kEmptyEventId || eof_) { + cv_.CancelWait(); + } else { + cv_.CommitWait(w); + triggered = trigger_event_; + } } - if (triggered) { - cv_.CancelWait(); - } else { - cv_.CommitWait(w); - triggered = trigger_event_; + + trigger_event_.store(kEmptyEventId, std::memory_order_relaxed); + waiting_.store(false, std::memory_order_relaxed); + std::string evt_name = + triggered == kEmptyEventId ? "NoEventNotifier" : GetEventName(triggered); + VLOG(10) << "Consume event id:" << triggered << ", name:" << evt_name; + // lazy deletion + { + std::lock_guard guard(events_lock_); + if (deleted_events_.size() > 0) { + for (auto evt : deleted_events_) { + events_.erase(evt); + } + deleted_events_.clear(); + } } - trigger_event_.store(nullptr, std::memory_order_relaxed); - waiting_.store(false); - auto trigger_event = *triggered; - delete triggered; - return trigger_event; + return evt_name; } int EventsWaiter::Clear() { @@ -106,32 +166,33 @@ int EventsWaiter::Clear() { std::memory_order_relaxed)) { return -1; } - trigger_event_.store(nullptr, std::memory_order_relaxed); + trigger_event_.store(kEmptyEventId, std::memory_order_relaxed); waiting_.store(false); return 0; } void EventsWaiter::TriggerEvent(const EventId& id) { VLOG(10) << "Try to trigger event id:" << id; - std::string* trigger_event = new std::string; - { - std::lock_guard guard(events_lock_); - auto iter = events_.find(id); - if (iter == events_.end()) { - delete trigger_event; - return; - } - *trigger_event = iter->second.name; + EventId prev = kEmptyEventId; + if (!trigger_event_.compare_exchange_strong( + prev, id, std::memory_order_seq_cst, std::memory_order_relaxed)) { + VLOG(10) << "Event id:" << prev << " is pending"; + return; } - std::string* prev = nullptr; - if (!trigger_event_.compare_exchange_strong(prev, trigger_event, + VLOG(10) << "Triggered event id:" << id; + cv_.Notify(true); +} + +void EventsWaiter::CancelEvent(const EventId& id) { + VLOG(10) << "Try to cancel event id:" << id; + EventId prev = id; + if (!trigger_event_.compare_exchange_strong(prev, kEmptyEventId, std::memory_order_seq_cst, std::memory_order_relaxed)) { - delete trigger_event; + VLOG(10) << "Event id:" << prev << " is pending"; return; } - VLOG(10) << "Triggered event id:" << id << " name:" << *trigger_event; - cv_.Notify(true); + VLOG(10) << "Cancelled event id:" << id; } std::string EventsWaiter::GetEventName(const EventId& id) { diff --git a/paddle/fluid/framework/new_executor/workqueue/events_waiter.h b/paddle/fluid/framework/new_executor/workqueue/events_waiter.h index 5ffed15155d592941c77a846b9df563b81d70c66..9d85f4a27242c9f9c8ed7ffa80879d626527dd35 100644 --- a/paddle/fluid/framework/new_executor/workqueue/events_waiter.h +++ b/paddle/fluid/framework/new_executor/workqueue/events_waiter.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "paddle/fluid/framework/new_executor/workqueue/event_count.h" #include "paddle/fluid/memory/allocation/spin_lock.h" @@ -37,13 +38,12 @@ class EventsWaiter { // Make sure EventsWaiter has a longer lifetime than EventNotifier. class EventNotifier { public: - void NotifyEvent() { waiter_.TriggerEvent(id_); } + ~EventNotifier() { waiter_.UnregisterEvent(id_); } - void UnregisterEvent() { waiter_.UnregisterEvent(id_); } + void NotifyEvent() { waiter_.TriggerEvent(id_); } - EventId GetEventId() { return id_; } + void CancelEvent() { waiter_.CancelEvent(id_); } - // return "Unregistered" if the corresponding event was unregistered. std::string GetEventName() { return waiter_.GetEventName(id_); } private: @@ -97,12 +97,16 @@ class EventsWaiter { void TriggerEvent(const EventId& id); + void CancelEvent(const EventId& id); + std::string GetEventName(const EventId& id); std::unordered_map events_; + std::unordered_set deleted_events_; paddle::memory::SpinLock events_lock_; - std::atomic trigger_event_; + std::atomic trigger_event_; std::atomic counter_; + std::atomic eof_; std::atomic waiting_; EventCount cv_; }; diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue/workqueue.cc index 596ffb9bfc0c4f624aeaf5874bdf18563d96d14c..881878ebb12a721e7b194036b9d36a89c5404365 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue.cc +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue.cc @@ -31,11 +31,8 @@ class WorkQueueImpl : public WorkQueue { public: explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) { if (options_.track_task && options.events_waiter != nullptr) { + empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent); void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); - TaskTracker* tracker = reinterpret_cast(storage); - empty_notifier_ = options.events_waiter->RegisterEvent( - kQueueEmptyEvent, - [tracker]() { return tracker->PendingTaskNum() == 0; }); tracker_ = new (storage) TaskTracker(*empty_notifier_.get()); } if (options_.detached == false && options.events_waiter != nullptr) { @@ -47,9 +44,6 @@ class WorkQueueImpl : public WorkQueue { } virtual ~WorkQueueImpl() { - if (empty_notifier_) { - empty_notifier_->UnregisterEvent(); - } delete queue_; if (tracker_ != nullptr) { tracker_->~TaskTracker(); @@ -57,7 +51,6 @@ class WorkQueueImpl : public WorkQueue { } if (destruct_notifier_) { destruct_notifier_->NotifyEvent(); - destruct_notifier_->UnregisterEvent(); } } @@ -124,14 +117,12 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( const auto& options = queues_options_[idx]; if (options.track_task && tracker_ == nullptr && options.events_waiter != nullptr) { + empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent); void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); - TaskTracker* tracker = reinterpret_cast(storage); - empty_notifier_ = options.events_waiter->RegisterEvent( - kQueueEmptyEvent, - [tracker]() { return tracker->PendingTaskNum() == 0; }); tracker_ = new (storage) TaskTracker(*empty_notifier_.get()); } - if (options.detached == false && options.events_waiter != nullptr) { + if (options.detached == false && options.events_waiter != nullptr && + !destruct_notifier_) { destruct_notifier_ = options.events_waiter->RegisterEvent(kQueueDestructEvent); } @@ -141,9 +132,6 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( } WorkQueueGroupImpl::~WorkQueueGroupImpl() { - if (empty_notifier_) { - empty_notifier_->UnregisterEvent(); - } for (auto queue : queues_) { queue->~NonblockingThreadPool(); } @@ -154,7 +142,6 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() { free(queues_storage_); if (destruct_notifier_) { destruct_notifier_->NotifyEvent(); - destruct_notifier_->UnregisterEvent(); } } diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc index 97f0282a15837e74e874202cd1891ff62de8d951..d8e09fb6baefe4cfc0e40cf0a1985f98853b9da5 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/new_executor/workqueue/workqueue.h" #include +#include #include "glog/logging.h" #include "gtest/gtest.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" @@ -26,11 +27,12 @@ TEST(WorkQueueUtils, TestEventsWaiter) { EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt"); EXPECT_EQ(notifier->GetEventName(), "test_register_lt"); EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt"); - notifier->UnregisterEvent(); - EXPECT_EQ(notifier->GetEventName(), "Unregistered"); + notifier.reset(); notifier = events_waiter.RegisterEvent("test_register_et"); notifier->NotifyEvent(); EXPECT_EQ(events_waiter.WaitEvent(), "test_register_et"); + notifier->NotifyEvent(); + notifier->CancelEvent(); } TEST(WorkQueue, TestSingleThreadedWorkQueue) { @@ -106,8 +108,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); // Cancel work_queue->Cancel(); + // Wait kQueueDestructEvent + std::thread waiter_thread([&events_waiter]() { + EXPECT_EQ(events_waiter.WaitEvent(), + paddle::framework::kQueueDestructEvent); + }); work_queue.reset(); - EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); + waiter_thread.join(); } TEST(WorkQueue, TestWorkQueueGroup) { @@ -154,10 +161,15 @@ TEST(WorkQueue, TestWorkQueueGroup) { // WaitQueueGroupEmpty events_waiter.WaitEvent(); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); + EXPECT_EQ(handle.get(), random_num); // Cancel queue_group->Cancel(); - events_waiter.WaitEvent(); + // Wait kQueueDestructEvent + std::thread waiter_thread([&events_waiter]() { + EXPECT_EQ(events_waiter.WaitEvent(), + paddle::framework::kQueueDestructEvent); + EXPECT_EQ(events_waiter.WaitEvent(), "NoEventNotifier"); + }); queue_group.reset(); - EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); - EXPECT_EQ(handle.get(), random_num); + waiter_thread.join(); } diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h b/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h index eee64df285dcb0aed23a8d4a4c622639cfe0772a..b6e6ede8c334fa58b6bacec9876a287a5bd0b3e0 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h @@ -81,7 +81,13 @@ class TaskTracker { ~TaskTracker() = default; - void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); } + void AddCounter() { + if (0 == num_tasks_.fetch_add(1, std::memory_order_relaxed)) { + if (notifier_ != nullptr) { + notifier_->CancelEvent(); + } + } + } void SubCounter() { if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {