Unverified commit 63471c83, authored by liutiexing, committed by GitHub

refine AsyncWorkQueue (#40977)

Parent 733d8168
Changes to AsyncWorkQueue (namespace paddle::framework::interpreter): the magic queue index 2 gets a name, and the two Prepare* methods below switch from hand-rolled promise/future plumbing to AddAwaitableTask.

@@ -29,6 +29,8 @@ namespace paddle {
 namespace framework {
 namespace interpreter {
 
+constexpr size_t kPrepareWorkQueueIdx = 2;
+
 void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
                              std::function<void()> fn) {
   // NOTE(zhiqiu): use the second queue, so only one thread is used.
@@ -47,34 +49,31 @@ using VariableIdMap = std::map<std::string, std::vector<int>>;
 
 void AsyncWorkQueue::PrepareAtomicDeps(
     const std::vector<size_t>& dependecy_count) {
   VLOG(4) << "PrepareAtomicDeps";
-  auto p = std::make_shared<
-      std::promise<std::unique_ptr<std::vector<std::atomic<size_t>>>>>();
-  atomic_deps_ = p->get_future();
-  queue_group_->AddTask(2, [&dependecy_count, p] {
-    auto* op_deps =
-        new std::vector<std::atomic<size_t>>(dependecy_count.size());
-    for (size_t i = 0; i < dependecy_count.size(); ++i) {
-      (*op_deps)[i] = dependecy_count[i];
-    }
-    VLOG(4) << "AtomicDeps:" << op_deps << " " << (*op_deps).size();
-    p->set_value(std::unique_ptr<std::vector<std::atomic<size_t>>>(op_deps));
-  });
+  atomic_deps_ =
+      queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&dependecy_count] {
+        auto op_deps = std::make_unique<std::vector<std::atomic<size_t>>>(
+            dependecy_count.size());
+        for (size_t i = 0; i < dependecy_count.size(); ++i) {
+          (*op_deps)[i] = dependecy_count[i];
+        }
+        VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size();
+        return op_deps;
+      });
 }
 
 void AsyncWorkQueue::PrepareAtomicVarRef(
     const std::vector<VariableMetaInfo>& vec_meta_info) {
   VLOG(4) << "PrepareAtomicVarRef";
-  auto p = std::make_shared<
-      std::promise<std::unique_ptr<std::vector<std::atomic<size_t>>>>>();
-  atomic_var_ref_ = p->get_future();
-  queue_group_->AddTask(2, [&vec_meta_info, p] {
-    auto* var_ref = new std::vector<std::atomic<size_t>>(vec_meta_info.size());
-    for (size_t i = 0; i < vec_meta_info.size(); ++i) {
-      (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
-    }
-    VLOG(4) << "AtomicVarRef:" << var_ref << " " << (*var_ref).size();
-    p->set_value(std::unique_ptr<std::vector<std::atomic<size_t>>>(var_ref));
-  });
+  atomic_var_ref_ =
+      queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&vec_meta_info] {
+        auto var_ref = std::make_unique<std::vector<std::atomic<size_t>>>(
+            vec_meta_info.size());
+        for (size_t i = 0; i < vec_meta_info.size(); ++i) {
+          (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
+        }
+        VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size();
+        return var_ref;
+      });
 }
 
 bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
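The net effect of this hunk: PrepareAtomicDeps and PrepareAtomicVarRef no longer allocate with raw new, share a std::promise with the task, and call set_value(); instead the lambda returns a std::unique_ptr and AddAwaitableTask hands back the future stored in atomic_deps_ / atomic_var_ref_. Below is a minimal sketch of how an awaitable-task API of this shape can be built; this is not Paddle's WorkQueue implementation, and TinyQueue, RunAll, and the single-threaded draining are invented for illustration, assuming AddAwaitableTask simply wraps the callable so a future can be returned.

#include <atomic>
#include <cstddef>
#include <functional>
#include <future>
#include <memory>
#include <queue>
#include <vector>

// Hypothetical single-threaded stand-in for a work queue exposing an
// AddAwaitableTask-style API: the callable is wrapped in std::packaged_task,
// so the caller gets a std::future instead of wiring up a std::promise and
// calling set_value() by hand, as the removed code did.
class TinyQueue {
 public:
  template <typename F>
  auto AddAwaitableTask(F&& fn) -> std::future<decltype(fn())> {
    using R = decltype(fn());
    auto task = std::make_shared<std::packaged_task<R()>>(std::forward<F>(fn));
    std::future<R> fut = task->get_future();
    tasks_.emplace([task] { (*task)(); });  // type-erased, queued for later
    return fut;
  }

  // Run queued tasks on the calling thread; a real queue would use workers.
  void RunAll() {
    while (!tasks_.empty()) {
      tasks_.front()();
      tasks_.pop();
    }
  }

 private:
  std::queue<std::function<void()>> tasks_;
};

int main() {
  TinyQueue q;
  std::vector<size_t> dependency_count = {2, 0, 1};
  // Mirrors the shape of PrepareAtomicDeps: build atomic counters inside the
  // task and simply return the unique_ptr; the future replaces the promise.
  auto fut = q.AddAwaitableTask([&dependency_count] {
    auto deps = std::make_unique<std::vector<std::atomic<size_t>>>(
        dependency_count.size());
    for (size_t i = 0; i < dependency_count.size(); ++i) {
      (*deps)[i] = dependency_count[i];
    }
    return deps;
  });
  q.RunAll();
  auto deps = fut.get();
  return deps->size() == 3 ? 0 : 1;
}

Wrapping the callable in std::packaged_task gives the same promise/future pairing as before, but the wiring lives inside the queue rather than at every call site, which is what lets the diff delete the shared_ptr<std::promise> boilerplate.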
The second file in the commit switches ThreadDataRegistry from a plain std::mutex to a shared (readers-writer) mutex, starting with the <shared_mutex> include.

@@ -17,6 +17,7 @@
 #include <functional>
 #include <memory>
 #include <mutex>
+#include <shared_mutex>
 #include <thread>
 #include <type_traits>
 #include <unordered_map>
@@ -60,16 +61,17 @@ class ThreadDataRegistry {
 
  private:
  // types
+  using LockType = std::shared_timed_mutex;
  class ThreadDataHolder;
  class ThreadDataRegistryImpl {
   public:
    void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) {
-      std::lock_guard<std::mutex> lock(lock_);
+      std::lock_guard<LockType> lock(lock_);
      tid_map_[tid] = tls_obj;
    }
 
    void UnregisterData(uint64_t tid) {
-      std::lock_guard<std::mutex> lock(lock_);
+      std::lock_guard<LockType> lock(lock_);
      tid_map_.erase(tid);
    }
 
@@ -77,7 +79,7 @@ class ThreadDataRegistry {
                  std::is_copy_constructible<Alias>::value>>
    std::unordered_map<uint64_t, T> GetAllThreadDataByValue() {
      std::unordered_map<uint64_t, T> data_copy;
-      std::lock_guard<std::mutex> lock(lock_);
+      std::shared_lock<LockType> lock(lock_);
      data_copy.reserve(tid_map_.size());
      for (auto& kv : tid_map_) {
        data_copy.emplace(kv.first, kv.second->GetData());
@@ -88,7 +90,7 @@ class ThreadDataRegistry {
    std::unordered_map<uint64_t, std::reference_wrapper<T>>
    GetAllThreadDataByRef() {
      std::unordered_map<uint64_t, std::reference_wrapper<T>> data_ref;
-      std::lock_guard<std::mutex> lock(lock_);
+      std::shared_lock<LockType> lock(lock_);
      data_ref.reserve(tid_map_.size());
      for (auto& kv : tid_map_) {
        data_ref.emplace(kv.first, std::ref(kv.second->GetData()));
@@ -97,7 +99,7 @@ class ThreadDataRegistry {
    }
 
   private:
-    std::mutex lock_;
+    LockType lock_;
    std::unordered_map<uint64_t, ThreadDataHolder*> tid_map_;  // not owned
  };
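The locking change follows the standard readers-writer pattern: RegisterData and UnregisterData still take an exclusive lock, while GetAllThreadDataByValue and GetAllThreadDataByRef now take std::shared_lock, so concurrent read-only snapshots no longer serialize each other. Here is a minimal, self-contained sketch of the same idiom under the same LockType alias; the Registry, Put, Erase, and Snapshot names are illustrative only, not Paddle's API.

#include <cstdint>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Illustrative registry using the idiom adopted by the diff: exclusive
// std::lock_guard for writers, std::shared_lock for read-only accessors.
class Registry {
 public:
  using LockType = std::shared_timed_mutex;  // from <shared_mutex>

  void Put(uint64_t tid, const std::string& value) {
    std::lock_guard<LockType> lock(lock_);  // writer: exclusive
    map_[tid] = value;
  }

  void Erase(uint64_t tid) {
    std::lock_guard<LockType> lock(lock_);  // writer: exclusive
    map_.erase(tid);
  }

  // Readers share the lock, so several snapshots can run concurrently.
  std::unordered_map<uint64_t, std::string> Snapshot() const {
    std::shared_lock<LockType> lock(lock_);
    return map_;
  }

 private:
  mutable LockType lock_;
  std::unordered_map<uint64_t, std::string> map_;
};

std::shared_timed_mutex is the shared mutex available since C++14; std::shared_mutex only arrived with C++17, which is a plausible reason the diff aliases the timed variant behind LockType.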