未验证 提交 198d11be 编写于 作者: L liutiexing 提交者: GitHub

Upgrade work queue (#38335)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Revert "Add EventsWaiter"

This reverts commit e206173aa9be7401b83a53581627bfaf557c8fb2.

* update EventsWater

* fix

* split workqueue files

* add more tests

* fix

* bugfix

* bugfix

* update
Co-authored-by: Nliutiexing <liutiexing@google.com>
上级 4221cd33
...@@ -2,8 +2,9 @@ set(INTERPRETERCORE_DEPS op_registry device_context scope framework_proto data_f ...@@ -2,8 +2,9 @@ set(INTERPRETERCORE_DEPS op_registry device_context scope framework_proto data_f
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor nan_inf_utils) graph_to_program_pass variable_helper timer monitor nan_inf_utils)
add_subdirectory(workqueue)
cc_library(data_transfer SRCS data_transfer.cc DEPS enforce scope glog) cc_library(data_transfer SRCS data_transfer.cc DEPS enforce scope glog)
cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc DEPS enforce)
cc_library(new_executor_defs SRCS new_executor_defs.cc DEPS enforce glog scope) cc_library(new_executor_defs SRCS new_executor_defs.cc DEPS enforce glog scope)
cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS} executor_gc_helper) cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS} executor_gc_helper)
cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue new_executor_defs data_transfer) cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue new_executor_defs data_transfer)
...@@ -11,7 +12,6 @@ cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog ne ...@@ -11,7 +12,6 @@ cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog ne
cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context new_executor_defs) cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context new_executor_defs)
cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager) cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager)
cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore) cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue)
# skip win32 since wget is not installed by default on windows machine. # skip win32 since wget is not installed by default on windows machine.
# skip COVERAGE_CI since the test runs slowly because of instrumentation. # skip COVERAGE_CI since the test runs slowly because of instrumentation.
......
...@@ -48,8 +48,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -48,8 +48,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
new interpreter::AsyncWorkQueue(kHostNumThreads, &main_thread_blocker_)); new interpreter::AsyncWorkQueue(kHostNumThreads, &main_thread_blocker_));
gc_.reset(new InterpreterCoreGarbageCollector()); gc_.reset(new InterpreterCoreGarbageCollector());
exception_notifier_ = main_thread_blocker_.RegisterEvent( exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
kExceptionCaught, [this]() { return exception_holder_.IsCaught(); });
create_local_scope_ = FLAGS_new_executor_use_local_scope; create_local_scope_ = FLAGS_new_executor_use_local_scope;
if (FLAGS_new_executor_use_local_scope) { if (FLAGS_new_executor_use_local_scope) {
......
...@@ -26,8 +26,6 @@ ...@@ -26,8 +26,6 @@
#include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/profiler.h" #include "paddle/fluid/framework/new_executor/profiler.h"
#include "paddle/fluid/framework/new_executor/stream_analyzer.h" #include "paddle/fluid/framework/new_executor/stream_analyzer.h"
#include "paddle/fluid/framework/new_executor/workqueue.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include <queue> #include <queue>
#include <vector> #include <vector>
#include "paddle/fluid/framework/new_executor/workqueue.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue.h"
#include "paddle/fluid/memory/allocation/spin_lock.h" #include "paddle/fluid/memory/allocation/spin_lock.h"
#include "paddle/fluid/platform/device_event.h" #include "paddle/fluid/platform/device_event.h"
......
...@@ -32,8 +32,8 @@ ...@@ -32,8 +32,8 @@
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/workqueue.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -61,12 +61,14 @@ class AsyncWorkQueue { ...@@ -61,12 +61,14 @@ class AsyncWorkQueue {
group_options.emplace_back(/*num_threads*/ host_num_threads, group_options.emplace_back(/*num_threads*/ host_num_threads,
/*allow_spinning*/ true, /*allow_spinning*/ true,
/*track_task*/ true, /*track_task*/ true,
/*queue_empty_waiter*/ waiter); /*detached*/ true,
/*events_waiter*/ waiter);
// for launch device Kernel // for launch device Kernel
group_options.emplace_back(/*num_threads*/ 1, group_options.emplace_back(/*num_threads*/ 1,
/*allow_spinning*/ true, /*allow_spinning*/ true,
/*track_task*/ true, /*track_task*/ true,
/*queue_empty_waiter*/ waiter); /*detached*/ true,
/*events_waiter*/ waiter);
queue_group_ = CreateWorkQueueGroup(group_options); queue_group_ = CreateWorkQueueGroup(group_options);
} }
......
cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc events_waiter.cc DEPS enforce glog)
cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue)
...@@ -41,6 +41,10 @@ ...@@ -41,6 +41,10 @@
// and won't block, or notifying thread will see state_ change and will unblock // and won't block, or notifying thread will see state_ change and will unblock
// the waiter, or both. But it can't happen that both threads don't see each // the waiter, or both. But it can't happen that both threads don't see each
// other changes, which would lead to deadlock. // other changes, which would lead to deadlock.
//
// Changes made by PaddlePaddle:
// 1. Allocate aligned storage for Waiters to get better performance.
// 2. Replace Eigen utils with std utils.
#pragma once #pragma once
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/workqueue/events_waiter.h"
#include <glog/logging.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
// Builds a waiter with an empty event slot (trigger_event_ == nullptr), an id
// counter starting at 0 and no thread currently waiting.
// NOTE(review): cv_ is constructed with the argument 1 -- presumably the
// number of waiter slots, matching the GetWaiter(0) call in WaitEvent();
// confirm against EventCount.
EventsWaiter::EventsWaiter()
: trigger_event_(nullptr), counter_(0), waiting_(false), cv_(1) {}
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name, EventChecker checker) {
auto counter = counter_.fetch_add(1);
auto id = std::hash<std::string>()(name + std::to_string(counter));
VLOG(10) << "Register event id:" << id << " name:" << name;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
EventInfo evt{id, name, TriggerType::LevelTriggered, std::move(checker)};
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_[id] = std::move(evt);
return notifier;
}
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name) {
auto counter = counter_.fetch_add(1);
auto id = std::hash<std::string>()(name + std::to_string(counter));
VLOG(10) << "Register event id:" << id << " name:" << name;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
EventInfo evt{id, name, TriggerType::EdgeTriggered, []() { return false; }};
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_[id] = std::move(evt);
return notifier;
}
// Removes the event with the given id from the registry. Unknown ids are
// ignored.
void EventsWaiter::UnregisterEvent(const EventId& id) {
  VLOG(10) << "Unregister event id:" << id;
  std::lock_guard<paddle::memory::SpinLock> holder(events_lock_);
  const auto position = events_.find(id);
  if (position != events_.end()) {
    events_.erase(position);
  }
}
// Blocks the calling thread until one registered event fires and returns the
// name of that event.
//
// Only one thread may wait at a time; a concurrent call throws
// ResourceExhausted. If an event is already parked in the slot it is returned
// immediately. Otherwise the level-triggered checkers are polled under
// events_lock_, and if none reports true the thread blocks on cv_ until
// TriggerEvent() publishes an event. The slot is emptied before returning.
std::string EventsWaiter::WaitEvent() {
  // only one user can wait at any time
  bool waiting = false;
  if (!waiting_.compare_exchange_strong(waiting, true,
                                        std::memory_order_seq_cst,
                                        std::memory_order_relaxed)) {
    PADDLE_THROW(
        platform::errors::ResourceExhausted("Another thread is waiting."));
  }
  auto w = cv_.GetWaiter(0);
  std::string* triggered = trigger_event_;
  // Loop rather than wait once: CommitWait() may return while the slot is
  // still empty (e.g. woken by a Notify() issued for an event that an earlier
  // call already consumed). Dereferencing the empty slot would be a
  // null-pointer access, so go back to waiting in that case.
  while (triggered == nullptr) {
    cv_.Prewait();
    // Re-check after Prewait() so an event published in between is not missed.
    triggered = trigger_event_;
    if (triggered == nullptr) {
      // checkers
      {
        std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
        for (auto& kv : events_) {
          auto& evt = kv.second;
          if (TriggerType::LevelTriggered == evt.type && evt.checker()) {
            triggered = new std::string(evt.name);
            break;
          }
        }
      }
      if (triggered != nullptr) {
        std::string* prev = nullptr;
        // Failure order is acquire: on failure `prev` points at a string
        // published by TriggerEvent() and is dereferenced below.
        if (!trigger_event_.compare_exchange_strong(
                prev, triggered, std::memory_order_seq_cst,
                std::memory_order_acquire)) {
          // A notifier won the race for the single slot; report its event.
          delete triggered;
          triggered = prev;
        }
      }
    }
    if (triggered != nullptr) {
      cv_.CancelWait();
    } else {
      cv_.CommitWait(w);
      triggered = trigger_event_;
    }
  }
  // From here on `triggered` is exclusively owned by this thread.
  trigger_event_.store(nullptr, std::memory_order_relaxed);
  waiting_.store(false);
  auto trigger_event = *triggered;
  delete triggered;
  return trigger_event;
}
// Non-blocking: empties the event slot whether or not an event is pending.
//
// Returns -1 if another thread is currently inside WaitEvent()/Clear(),
// 0 on success.
int EventsWaiter::Clear() {
  bool waiting = false;
  if (!waiting_.compare_exchange_strong(waiting, true,
                                        std::memory_order_seq_cst,
                                        std::memory_order_relaxed)) {
    return -1;
  }
  // The slot owns a heap-allocated string; take it out and free it instead of
  // overwriting the pointer, which would leak the pending event. Deleting a
  // null pointer is a no-op, so the empty-slot case needs no branch.
  delete trigger_event_.exchange(nullptr, std::memory_order_acq_rel);
  waiting_.store(false);
  return 0;
}
// Publishes the event with the given id into the single event slot and wakes
// the waiter.
//
// Does nothing if the id is not registered. If the slot already holds an
// event, this one is dropped.
void EventsWaiter::TriggerEvent(const EventId& id) {
  VLOG(10) << "Try to trigger event id:" << id;
  // Copy the name while holding the lock; allocate only once the event is
  // known to exist.
  std::string name;
  {
    std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
    auto iter = events_.find(id);
    if (iter == events_.end()) {
      return;
    }
    name = iter->second.name;
  }
  std::string* trigger_event = new std::string(name);
  std::string* prev = nullptr;
  if (!trigger_event_.compare_exchange_strong(prev, trigger_event,
                                              std::memory_order_seq_cst,
                                              std::memory_order_relaxed)) {
    // Slot occupied: ownership was not transferred, so free our copy.
    delete trigger_event;
    return;
  }
  // After a successful exchange the heap string belongs to the slot and may be
  // deleted by the waiter at any moment, so log from the local copy only.
  VLOG(10) << "Triggered event id:" << id << " name:" << name;
  cv_.Notify(true);
}
// Looks up the name registered for `id`; yields "Unregistered" when the id is
// not (or no longer) present in the registry.
std::string EventsWaiter::GetEventName(const EventId& id) {
  std::lock_guard<paddle::memory::SpinLock> holder(events_lock_);
  const auto found = events_.find(id);
  return found != events_.end() ? found->second.name
                                : std::string("Unregistered");
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/new_executor/workqueue/event_count.h"
#include "paddle/fluid/memory/allocation/spin_lock.h"
namespace paddle {
namespace framework {
// A multiplexing waiter that can wait on multiple kinds of events
// simultaneously.
// Multi-producer, single-consumer, single-slot message queue.
class EventsWaiter {
 public:
  using EventId = std::size_t;
  using EventChecker = std::function<bool()>;

  // Handle for one registered event; created only by
  // EventsWaiter::RegisterEvent(). It holds a reference to the waiter, so
  // make sure EventsWaiter has a longer lifetime than EventNotifier.
  class EventNotifier {
   public:
    // Publishes this event to the waiter.
    void NotifyEvent() { waiter_.TriggerEvent(id_); }

    // Removes this event from the waiter's registry.
    void UnregisterEvent() { waiter_.UnregisterEvent(id_); }

    EventId GetEventId() { return id_; }

    // return "Unregistered" if the corresponding event was unregistered.
    std::string GetEventName() { return waiter_.GetEventName(id_); }

   private:
    friend EventsWaiter;
    EventNotifier(EventId id, EventsWaiter* waiter)
        : id_(id), waiter_(*waiter) {}
    EventNotifier(const EventNotifier&) = delete;
    void operator=(const EventNotifier&) = delete;

    EventId id_;
    EventsWaiter& waiter_;
  };

  EventsWaiter();
  EventsWaiter(const EventsWaiter&) = delete;
  EventsWaiter& operator=(const EventsWaiter&) = delete;

  // The event slot owns a heap-allocated name; free an event that was
  // published but never consumed so it does not leak with the waiter.
  ~EventsWaiter() { delete trigger_event_.load(std::memory_order_acquire); }

  // Register a level-triggered event. If the checker returns true or
  // EventNotifier::NotifyEvent is called, the corresponding event will be
  // distributed.
  std::shared_ptr<EventNotifier> RegisterEvent(const std::string& name,
                                               EventChecker checker);

  // Register an edge-triggered event. The corresponding event will be
  // distributed when EventNotifier::NotifyEvent is called.
  std::shared_ptr<EventNotifier> RegisterEvent(const std::string& name);

  // Remove the event with the given id from the registry.
  void UnregisterEvent(const EventId& id);

  // Blocking the calling thread to wait any of the registered events.
  // Returns the name of the event that fired.
  std::string WaitEvent();

  // Nonblocking.
  // Clear the slot, no matter whether there is an event.
  // Return value:
  // -1 : another thread is waiting.
  // 0 : succ.
  int Clear();

 private:
  friend EventNotifier;

  enum class TriggerType { LevelTriggered, EdgeTriggered };

  // Registry entry for one event.
  struct EventInfo {
    EventId id;
    std::string name;
    TriggerType type;
    EventChecker checker;
  };

  void TriggerEvent(const EventId& id);
  std::string GetEventName(const EventId& id);

  // Registered events keyed by id; guarded by events_lock_.
  std::unordered_map<EventId, EventInfo> events_;
  paddle::memory::SpinLock events_lock_;
  // Single event slot: nullptr when empty, otherwise an owned heap string
  // holding the name of the pending event.
  std::atomic<std::string*> trigger_event_;
  // Sequence number mixed into event ids by RegisterEvent().
  std::atomic<uint64_t> counter_;
  // True while a thread is inside WaitEvent() or Clear().
  std::atomic<bool> waiting_;
  EventCount cv_;
};
} // namespace framework
} // namespace paddle
...@@ -12,43 +12,14 @@ ...@@ -12,43 +12,14 @@
#include <atomic> #include <atomic>
#include <cstdlib> #include <cstdlib>
#include <vector> #include <vector>
#include "paddle/fluid/framework/new_executor/event_count.h" #include "paddle/fluid/framework/new_executor/workqueue/event_count.h"
#include "paddle/fluid/framework/new_executor/run_queue.h" #include "paddle/fluid/framework/new_executor/workqueue/run_queue.h"
#include "paddle/fluid/framework/new_executor/thread_environment.h" #include "paddle/fluid/framework/new_executor/workqueue/thread_environment.h"
#include "paddle/fluid/platform/os_info.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Counts in-flight tasks and tells an optional listener when the count drops
// to zero.
//
// Notifier must provide a NotifyEvent() member function. The tracker keeps a
// non-owning pointer to it; a default-constructed tracker has no listener and
// only counts.
template <typename Notifier>
class TaskTracker {
 public:
  TaskTracker() = default;
  explicit TaskTracker(Notifier& notifier) : notifier_(&notifier) {}
  TaskTracker(const TaskTracker&) = delete;
  TaskTracker& operator=(const TaskTracker&) = delete;
  ~TaskTracker() = default;

  // Records that one more task is pending.
  void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); }

  // Records that one task finished; notifies the listener when this was the
  // last pending task.
  void SubCounter() {
    // acq_rel instead of relaxed: each decrement releases the work its task
    // performed, and the final decrement acquires all earlier ones, so whoever
    // is told (or observes) that the count reached zero also sees the effects
    // of every finished task.
    if (1 == num_tasks_.fetch_sub(1, std::memory_order_acq_rel)) {
      if (notifier_ != nullptr) {
        notifier_->NotifyEvent();
      }
    }
  }

  // Number of tasks added but not yet finished.
  uint64_t PendingTaskNum() const { return num_tasks_.load(); }

 private:
  // Kept on its own 64-byte boundary.
  alignas(64) std::atomic<uint64_t> num_tasks_{0};
  Notifier* notifier_{nullptr};
};
template <typename Environment> template <typename Environment>
class ThreadPoolTempl { class ThreadPoolTempl {
public: public:
......
...@@ -29,6 +29,11 @@ ...@@ -29,6 +29,11 @@
// separate state variable as null/non-null pointer value would serve as state, // separate state variable as null/non-null pointer value would serve as state,
// but that would require malloc/free per operation for large, complex values // but that would require malloc/free per operation for large, complex values
// (and this is designed to store std::function<()>). // (and this is designed to store std::function<()>).
//
// Changes made by PaddlePaddle:
// 1. Use paddle::memory::SpinLock instead of std::mutex to protect back_.
// 2. Make front_/back_ aligned to get better performance.
// 3. Replace Eigen utils with std utils.
#pragma once #pragma once
...@@ -37,7 +42,7 @@ ...@@ -37,7 +42,7 @@
#include <cstdint> #include <cstdint>
#include <mutex> #include <mutex>
#include <vector> #include <vector>
#include "paddle/fluid/framework/new_executor/workqueue_utils.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
#include "paddle/fluid/memory/allocation/spin_lock.h" #include "paddle/fluid/memory/allocation/spin_lock.h"
namespace paddle { namespace paddle {
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
// Public License v. 2.0. If a copy of the MPL was not distributed // Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include "paddle/fluid/framework/new_executor/workqueue.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue.h"
#include "paddle/fluid/framework/new_executor/nonblocking_threadpool.h" #include "paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -18,24 +18,35 @@ using TaskTracker = TaskTracker<EventsWaiter::EventNotifier>; ...@@ -18,24 +18,35 @@ using TaskTracker = TaskTracker<EventsWaiter::EventNotifier>;
class WorkQueueImpl : public WorkQueue { class WorkQueueImpl : public WorkQueue {
public: public:
explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) { explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) {
if (options_.track_task && options.queue_empty_waiter != nullptr) { if (options_.track_task && options.events_waiter != nullptr) {
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage); TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
auto notifier = options.queue_empty_waiter->RegisterEvent( empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent, kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; }); [tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*notifier.get()); tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
}
if (options_.detached == false && options.events_waiter != nullptr) {
destruct_notifier_ =
options.events_waiter->RegisterEvent(kQueueDestructEvent);
} }
queue_ = new NonblockingThreadPool(options_.num_threads, queue_ = new NonblockingThreadPool(options_.num_threads,
options_.allow_spinning); options_.allow_spinning);
} }
virtual ~WorkQueueImpl() { virtual ~WorkQueueImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
delete queue_;
if (tracker_ != nullptr) { if (tracker_ != nullptr) {
tracker_->~TaskTracker(); tracker_->~TaskTracker();
AlignedFree(tracker_); AlignedFree(tracker_);
} }
delete queue_; if (destruct_notifier_) {
destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
}
} }
void AddTask(std::function<void()> fn) override { void AddTask(std::function<void()> fn) override {
...@@ -59,6 +70,8 @@ class WorkQueueImpl : public WorkQueue { ...@@ -59,6 +70,8 @@ class WorkQueueImpl : public WorkQueue {
private: private:
NonblockingThreadPool* queue_{nullptr}; NonblockingThreadPool* queue_{nullptr};
TaskTracker* tracker_{nullptr}; TaskTracker* tracker_{nullptr};
std::shared_ptr<EventsWaiter::EventNotifier> empty_notifier_;
std::shared_ptr<EventsWaiter::EventNotifier> destruct_notifier_;
}; };
class WorkQueueGroupImpl : public WorkQueueGroup { class WorkQueueGroupImpl : public WorkQueueGroup {
...@@ -80,6 +93,8 @@ class WorkQueueGroupImpl : public WorkQueueGroup { ...@@ -80,6 +93,8 @@ class WorkQueueGroupImpl : public WorkQueueGroup {
std::vector<NonblockingThreadPool*> queues_; std::vector<NonblockingThreadPool*> queues_;
NonblockingThreadPool* queues_storage_; NonblockingThreadPool* queues_storage_;
TaskTracker* tracker_; TaskTracker* tracker_;
std::shared_ptr<EventsWaiter::EventNotifier> empty_notifier_;
std::shared_ptr<EventsWaiter::EventNotifier> destruct_notifier_;
}; };
WorkQueueGroupImpl::WorkQueueGroupImpl( WorkQueueGroupImpl::WorkQueueGroupImpl(
...@@ -94,13 +109,17 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( ...@@ -94,13 +109,17 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
for (size_t idx = 0; idx < num_queues; ++idx) { for (size_t idx = 0; idx < num_queues; ++idx) {
const auto& options = queues_options_[idx]; const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr && if (options.track_task && tracker_ == nullptr &&
options.queue_empty_waiter != nullptr) { options.events_waiter != nullptr) {
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage); TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
auto notifier = options.queue_empty_waiter->RegisterEvent( empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent, kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; }); [tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*notifier.get()); tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
}
if (options.detached == false && options.events_waiter != nullptr) {
destruct_notifier_ =
options.events_waiter->RegisterEvent(kQueueDestructEvent);
} }
queues_[idx] = new (&queues_storage_[idx]) queues_[idx] = new (&queues_storage_[idx])
NonblockingThreadPool(options.num_threads, options.allow_spinning); NonblockingThreadPool(options.num_threads, options.allow_spinning);
...@@ -108,6 +127,9 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( ...@@ -108,6 +127,9 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
} }
WorkQueueGroupImpl::~WorkQueueGroupImpl() { WorkQueueGroupImpl::~WorkQueueGroupImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
for (auto queue : queues_) { for (auto queue : queues_) {
queue->~NonblockingThreadPool(); queue->~NonblockingThreadPool();
} }
...@@ -116,6 +138,10 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() { ...@@ -116,6 +138,10 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() {
AlignedFree(tracker_); AlignedFree(tracker_);
} }
free(queues_storage_); free(queues_storage_);
if (destruct_notifier_) {
destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
}
} }
void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) { void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) {
......
...@@ -22,6 +22,7 @@ namespace paddle { ...@@ -22,6 +22,7 @@ namespace paddle {
namespace framework { namespace framework {
constexpr const char* kQueueEmptyEvent = "QueueEmpty"; constexpr const char* kQueueEmptyEvent = "QueueEmpty";
constexpr const char* kQueueDestructEvent = "QueueDestruct";
class EventsWaiter; class EventsWaiter;
...@@ -32,20 +33,24 @@ struct WorkQueueOptions { ...@@ -32,20 +33,24 @@ struct WorkQueueOptions {
track_task(track_task) {} track_task(track_task) {}
WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task, WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task,
EventsWaiter* waiter) bool detached, EventsWaiter* waiter)
: num_threads(num_threads), : num_threads(num_threads),
allow_spinning(allow_spinning), allow_spinning(allow_spinning),
track_task(track_task), track_task(track_task),
queue_empty_waiter(waiter) {} detached(detached),
events_waiter(waiter) {}
size_t num_threads; size_t num_threads;
bool allow_spinning; bool allow_spinning;
// If you need to blocking the calling thread to wait "queue empty", set // If you need to blocking the calling thread to wait "queue empty", set
// track_task = true and set queue_empty_waiter. EventsWaiter::WaitEvent will // track_task = true and set events_waiter. EventsWaiter::WaitEvent will
// block the calling thread until any of events (including "queue empty") // block the calling thread until any of events (including "queue empty")
// occured. // occured.
bool track_task; bool track_task;
EventsWaiter* queue_empty_waiter{nullptr}; // not owned // If you need to be noticed when a WorkQueue Destruct() , set detached =
// false and set events_waiter.
bool detached{true};
EventsWaiter* events_waiter{nullptr}; // not owned
}; };
class WorkQueue { class WorkQueue {
......
...@@ -12,11 +12,26 @@ ...@@ -12,11 +12,26 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/new_executor/workqueue.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue.h"
#include <atomic> #include <atomic>
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
// Exercises EventsWaiter on its own: a level-triggered event whose checker is
// always true is reported on every wait, an unregistered event reports the
// name "Unregistered", and an edge-triggered event is reported after an
// explicit NotifyEvent().
TEST(WorkQueueUtils, TestEventsWaiter) {
  using paddle::framework::EventsWaiter;
  const char* const level_event = "test_register_lt";
  const char* const edge_event = "test_register_et";

  EventsWaiter waiter;

  // Level-triggered: fires on every wait while the checker returns true.
  auto handle = waiter.RegisterEvent(level_event, []() { return true; });
  EXPECT_EQ(waiter.WaitEvent(), level_event);
  EXPECT_EQ(handle->GetEventName(), level_event);
  EXPECT_EQ(waiter.WaitEvent(), level_event);

  handle->UnregisterEvent();
  EXPECT_EQ(handle->GetEventName(), "Unregistered");

  // Edge-triggered: fires only after an explicit notification.
  handle = waiter.RegisterEvent(edge_event);
  handle->NotifyEvent();
  EXPECT_EQ(waiter.WaitEvent(), edge_event);
}
TEST(WorkQueue, TestSingleThreadedWorkQueue) { TEST(WorkQueue, TestSingleThreadedWorkQueue) {
VLOG(1) << "In Test"; VLOG(1) << "In Test";
...@@ -30,7 +45,8 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { ...@@ -30,7 +45,8 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
// CreateSingleThreadedWorkQueue // CreateSingleThreadedWorkQueue
EventsWaiter events_waiter; EventsWaiter events_waiter;
WorkQueueOptions options(/*num_threads*/ 1, /*allow_spinning*/ true, WorkQueueOptions options(/*num_threads*/ 1, /*allow_spinning*/ true,
/*track_task*/ true, &events_waiter); /*track_task*/ true, /*detached*/ true,
&events_waiter);
auto work_queue = CreateSingleThreadedWorkQueue(options); auto work_queue = CreateSingleThreadedWorkQueue(options);
// NumThreads // NumThreads
EXPECT_EQ(work_queue->NumThreads(), 1u); EXPECT_EQ(work_queue->NumThreads(), 1u);
...@@ -63,7 +79,8 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { ...@@ -63,7 +79,8 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
// CreateMultiThreadedWorkQueue // CreateMultiThreadedWorkQueue
EventsWaiter events_waiter; EventsWaiter events_waiter;
WorkQueueOptions options(/*num_threads*/ 10, /*allow_spinning*/ true, WorkQueueOptions options(/*num_threads*/ 10, /*allow_spinning*/ true,
/*track_task*/ true, &events_waiter); /*track_task*/ true, /*detached*/ false,
&events_waiter);
auto work_queue = CreateMultiThreadedWorkQueue(options); auto work_queue = CreateMultiThreadedWorkQueue(options);
// NumThreads // NumThreads
EXPECT_EQ(work_queue->NumThreads(), 10u); EXPECT_EQ(work_queue->NumThreads(), 10u);
...@@ -80,11 +97,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { ...@@ -80,11 +97,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
} }
// WaitQueueEmpty // WaitQueueEmpty
EXPECT_EQ(finished.load(), false); EXPECT_EQ(finished.load(), false);
events_waiter.WaitEvent(); EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueEmptyEvent);
EXPECT_EQ(finished.load(), true); EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum);
// Cancel // Cancel
work_queue->Cancel(); work_queue->Cancel();
work_queue.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent);
} }
TEST(WorkQueue, TestWorkQueueGroup) { TEST(WorkQueue, TestWorkQueueGroup) {
...@@ -99,9 +118,11 @@ TEST(WorkQueue, TestWorkQueueGroup) { ...@@ -99,9 +118,11 @@ TEST(WorkQueue, TestWorkQueueGroup) {
// ThreadedWorkQueueGroup // ThreadedWorkQueueGroup
EventsWaiter events_waiter; EventsWaiter events_waiter;
WorkQueueOptions sq_options(/*num_threads*/ 1, /*allow_spinning*/ true, WorkQueueOptions sq_options(/*num_threads*/ 1, /*allow_spinning*/ true,
/*track_task*/ true, &events_waiter); /*track_task*/ true, /*detached*/ false,
&events_waiter);
WorkQueueOptions mq_options(/*num_threads*/ 10, /*allow_spinning*/ true, WorkQueueOptions mq_options(/*num_threads*/ 10, /*allow_spinning*/ true,
/*track_task*/ true, &events_waiter); /*track_task*/ true, /*detached*/ false,
&events_waiter);
auto queue_group = CreateWorkQueueGroup({sq_options, mq_options}); auto queue_group = CreateWorkQueueGroup({sq_options, mq_options});
// NumThreads // NumThreads
EXPECT_EQ(queue_group->QueueNumThreads(0), 1u); EXPECT_EQ(queue_group->QueueNumThreads(0), 1u);
...@@ -126,4 +147,7 @@ TEST(WorkQueue, TestWorkQueueGroup) { ...@@ -126,4 +147,7 @@ TEST(WorkQueue, TestWorkQueueGroup) {
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
// Cancel // Cancel
queue_group->Cancel(); queue_group->Cancel();
events_waiter.WaitEvent();
queue_group.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent);
} }
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/new_executor/workqueue_utils.h" #include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
#include <cstdint> #include <cstdint>
#include <cstdlib> #include <cstdlib>
...@@ -55,62 +55,5 @@ void AlignedFree(void* mem_ptr) { ...@@ -55,62 +55,5 @@ void AlignedFree(void* mem_ptr) {
#endif #endif
} }
constexpr EventsWaiter::EventId kEmptyEventId = -1;
EventsWaiter::EventsWaiter()
: trigger_event_(kEmptyEventId), waiting_(false), cv_(1) {}
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name, EventChecker checker) {
names_.emplace_back(name);
checkers_.emplace_back(std::move(checker));
EventId id = checkers_.size() - 1;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
notifiers_.emplace_back(notifier);
return notifier;
}
// Waits for one of the registered events and returns its name.
//
// The registered checkers are polled once, in registration order; if one
// returns true its event is reported without committing the wait. Otherwise
// the call goes through cv_.CommitWait and reports the id found in
// trigger_event_ afterwards.
//
// Throws (via PADDLE_THROW, ResourceExhausted) when another call to
// WaitEvent is already in progress.
std::string EventsWaiter::WaitEvent() {
// Only one caller may wait at any time: atomically flip waiting_ from
// false to true, and fail if it was already true.
bool waiting = false;
if (!waiting_.compare_exchange_strong(waiting, true,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
PADDLE_THROW(
platform::errors::ResourceExhausted("Another thread is waiting."));
}
EventId id = kEmptyEventId;
auto w = cv_.GetWaiter(0);
// Prewait is issued before the checkers are polled.
// NOTE(review): presumably this is what keeps a Notify that races with the
// polling below from being lost -- confirm against EventCount.
cv_.Prewait();
int64_t event_num = checkers_.size();
// Stop at the first checker that reports true; its index is the event id.
for (int64_t i = 0; id == kEmptyEventId && i < event_num; ++i) {
if (checkers_[i]()) {
id = i;
}
}
if (id != kEmptyEventId) {
// An event was already observed by polling, so the wait is withdrawn.
cv_.CancelWait();
} else {
// Nothing observed by polling: commit the wait, then read the id that
// SetTriggerEvent stored.
// NOTE(review): presumably CommitWait blocks until cv_.Notify is called --
// confirm against EventCount.
cv_.CommitWait(w);
id = trigger_event_.load(std::memory_order_relaxed);
}
// Clear the recorded trigger and release the single-waiter flag.
// NOTE(review): neither reset happens if a checker above throws; there is
// no try/catch here -- confirm checkers are expected not to throw.
trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false);
// NOTE(review): if id is still kEmptyEventId at this point, at() would be
// called with an out-of-range index -- confirm that cannot happen.
return names_.at(id);
}
// Records `id` as the triggering event and then calls cv_.Notify(true).
// The id is stored before the notification is issued; WaitEvent reads
// trigger_event_ after its CommitWait call returns.
// NOTE(review): presumably Notify(true) wakes the thread blocked in
// WaitEvent -- confirm against EventCount::Notify.
void EventsWaiter::SetTriggerEvent(const EventId& id) {
trigger_event_.store(id, std::memory_order_relaxed);
cv_.Notify(true);
}
// Returns the name stored in the owning waiter's names_ at this notifier's
// id. Uses the bounds-checked at() accessor.
std::string EventsWaiter::EventNotifier::GetEventName() {
  const auto& registered_names = waiter_.names_;
  return registered_names.at(id_);
}
// Reports this notifier's event to the owning EventsWaiter by forwarding
// its id to EventsWaiter::SetTriggerEvent.
void EventsWaiter::EventNotifier::NotifyEvent() {
waiter_.SetTriggerEvent(id_);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -21,8 +21,7 @@ ...@@ -21,8 +21,7 @@
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <vector> #include "paddle/fluid/framework/new_executor/workqueue/events_waiter.h"
#include "paddle/fluid/framework/new_executor/event_count.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -69,55 +68,34 @@ void* AlignedMalloc(size_t size, size_t alignment); ...@@ -69,55 +68,34 @@ void* AlignedMalloc(size_t size, size_t alignment);
void AlignedFree(void* memory_ptr); void AlignedFree(void* memory_ptr);
// NOTE(review): the lines below are a side-by-side diff rendering. The left
// column is the removed EventsWaiter class; the right column is the added
// TaskTracker<Notifier> template. The two are interleaved line by line.
//
// EventsWaiter (removed here): declares EventId, EventChecker, the nested
// EventNotifier handle (NotifyEvent / GetEventId / GetEventName), and the
// RegisterEvent / WaitEvent entry points. It is non-copyable.
//
// TaskTracker<Notifier> (added here): keeps an atomic task counter.
//   - AddCounter() increments the counter (relaxed ordering).
//   - SubCounter() decrements it; when the value before the decrement was 1
//     and a notifier was supplied, it calls notifier_->NotifyEvent().
//   - PendingTaskNum() returns the current counter value.
// It is non-copyable, and the counter is declared alignas(64).
// A multiplexing waiter, be able to wait multi events simultaneously. template <typename Notifier>
// Blocking the calling thread to wait any of the registered events. class TaskTracker {
// Non-thread-safe.
class EventsWaiter {
public: public:
using EventId = int64_t; TaskTracker() = default;
using EventChecker = std::function<bool()>; explicit TaskTracker(Notifier& notifier) : notifier_(&notifier) {}
class EventNotifier { TaskTracker(const TaskTracker&) = delete;
public:
void NotifyEvent();
EventId GetEventId() { return id_; } TaskTracker& operator=(const TaskTracker&) = delete;
std::string GetEventName(); ~TaskTracker() = default;
private: void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); }
friend EventsWaiter;
EventNotifier(EventId id, EventsWaiter* waiter)
: id_(id), waiter_(*waiter) {}
EventId id_; void SubCounter() {
EventsWaiter& waiter_; if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {
}; if (notifier_ != nullptr) {
notifier_->NotifyEvent();
EventsWaiter(); }
}
EventsWaiter(const EventsWaiter&) = delete; }
EventsWaiter& operator=(const EventsWaiter&) = delete;
// All the RegisterEvent functions must be called before any WaitEvent
std::shared_ptr<EventNotifier> RegisterEvent(const std::string& name,
EventChecker checker);
// Wait any of the registered events uint64_t PendingTaskNum() { return num_tasks_.load(); }
std::string WaitEvent();
private: private:
friend EventNotifier; alignas(64) std::atomic<uint64_t> num_tasks_{0};
void SetTriggerEvent(const EventId& id); Notifier* notifier_{nullptr};
std::vector<std::string> names_;
std::vector<EventChecker> checkers_;
std::vector<std::shared_ptr<EventNotifier>> notifiers_;
std::atomic<EventId> trigger_event_;
std::atomic<bool> waiting_;
EventCount cv_;
}; };
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册