diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 08f2dd3dc2487cc6827c19f9a3f148d511f684af..b9470cd3736d1e95d19e7414920bfb471212c542 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -29,6 +29,8 @@ namespace paddle { namespace framework { namespace interpreter { +constexpr size_t kPrepareWorkQueueIdx = 2; + void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type, std::function fn) { // NOTE(zhiqiu): use thhe second queue of size of, so only one thread is used. @@ -47,34 +49,31 @@ using VariableIdMap = std::map>; void AsyncWorkQueue::PrepareAtomicDeps( const std::vector& dependecy_count) { VLOG(4) << "PrepareAtomicDeps"; - auto p = std::make_shared< - std::promise>>>>(); - atomic_deps_ = p->get_future(); - queue_group_->AddTask(2, [&dependecy_count, p] { - auto* op_deps = - new std::vector>(dependecy_count.size()); - for (size_t i = 0; i < dependecy_count.size(); ++i) { - (*op_deps)[i] = dependecy_count[i]; - } - VLOG(4) << "AtomicDeps:" << op_deps << " " << (*op_deps).size(); - p->set_value(std::unique_ptr>>(op_deps)); - }); + atomic_deps_ = + queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&dependecy_count] { + auto op_deps = std::make_unique>>( + dependecy_count.size()); + for (size_t i = 0; i < dependecy_count.size(); ++i) { + (*op_deps)[i] = dependecy_count[i]; + } + VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size(); + return op_deps; + }); } void AsyncWorkQueue::PrepareAtomicVarRef( const std::vector& vec_meta_info) { VLOG(4) << "PrepareAtomicVarRef"; - auto p = std::make_shared< - std::promise>>>>(); - atomic_var_ref_ = p->get_future(); - queue_group_->AddTask(2, [&vec_meta_info, p] { - auto* var_ref = new std::vector>(vec_meta_info.size()); - for (size_t i = 0; i < vec_meta_info.size(); ++i) { - (*var_ref)[i] = vec_meta_info[i].var_ref_count_; - } - VLOG(4) << "AtomicVarRef:" << var_ref << " " << (*var_ref).size(); - p->set_value(std::unique_ptr>>(var_ref)); - }); + atomic_var_ref_ = + queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&vec_meta_info] { + auto var_ref = std::make_unique>>( + vec_meta_info.size()); + for (size_t i = 0; i < vec_meta_info.size(); ++i) { + (*var_ref)[i] = vec_meta_info[i].var_ref_count_; + } + VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size(); + return var_ref; + }); } bool var_can_be_deleted(const std::string& name, const BlockDesc& block) { diff --git a/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h b/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h index b46bf7a8d34821c43c4d455790c97486351147de..ffdddc39a31e39b0e5f16438ff52a6613fd7bdfe 100644 --- a/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h +++ b/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -60,16 +61,17 @@ class ThreadDataRegistry { private: // types + using LockType = std::shared_timed_mutex; class ThreadDataHolder; class ThreadDataRegistryImpl { public: void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) { - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); tid_map_[tid] = tls_obj; } void UnregisterData(uint64_t tid) { - std::lock_guard lock(lock_); + std::lock_guard lock(lock_); tid_map_.erase(tid); } @@ -77,7 +79,7 @@ class ThreadDataRegistry { std::is_copy_constructible::value>> std::unordered_map GetAllThreadDataByValue() { std::unordered_map data_copy; - std::lock_guard lock(lock_); + std::shared_lock lock(lock_); data_copy.reserve(tid_map_.size()); for (auto& kv : tid_map_) { data_copy.emplace(kv.first, kv.second->GetData()); @@ -88,7 +90,7 @@ class ThreadDataRegistry { std::unordered_map> GetAllThreadDataByRef() { std::unordered_map> data_ref; - std::lock_guard lock(lock_); + std::shared_lock lock(lock_); data_ref.reserve(tid_map_.size()); for (auto& kv : tid_map_) { data_ref.emplace(kv.first, std::ref(kv.second->GetData())); @@ -97,7 +99,7 @@ class ThreadDataRegistry { } private: - std::mutex lock_; + LockType lock_; std::unordered_map tid_map_; // not owned };