Unverified Commit 63471c83 authored by liutiexing, committed by GitHub

refine AsyncWorkQueue (#40977)

Parent 733d8168
@@ -29,6 +29,8 @@ namespace paddle {
namespace framework {
namespace interpreter {
+constexpr size_t kPrepareWorkQueueIdx = 2;
void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
                             std::function<void()> fn) {
  // NOTE(zhiqiu): use the second queue of size of, so only one thread is used.
@@ -47,34 +49,31 @@ using VariableIdMap = std::map<std::string, std::vector<int>>;
void AsyncWorkQueue::PrepareAtomicDeps(
    const std::vector<size_t>& dependecy_count) {
  VLOG(4) << "PrepareAtomicDeps";
-  auto p = std::make_shared<
-      std::promise<std::unique_ptr<std::vector<std::atomic<size_t>>>>>();
-  atomic_deps_ = p->get_future();
-  queue_group_->AddTask(2, [&dependecy_count, p] {
-    auto* op_deps =
-        new std::vector<std::atomic<size_t>>(dependecy_count.size());
-    for (size_t i = 0; i < dependecy_count.size(); ++i) {
-      (*op_deps)[i] = dependecy_count[i];
-    }
-    VLOG(4) << "AtomicDeps:" << op_deps << " " << (*op_deps).size();
-    p->set_value(std::unique_ptr<std::vector<std::atomic<size_t>>>(op_deps));
-  });
+  atomic_deps_ =
+      queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&dependecy_count] {
+        auto op_deps = std::make_unique<std::vector<std::atomic<size_t>>>(
+            dependecy_count.size());
+        for (size_t i = 0; i < dependecy_count.size(); ++i) {
+          (*op_deps)[i] = dependecy_count[i];
+        }
+        VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size();
+        return op_deps;
+      });
}

void AsyncWorkQueue::PrepareAtomicVarRef(
    const std::vector<VariableMetaInfo>& vec_meta_info) {
  VLOG(4) << "PrepareAtomicVarRef";
-  auto p = std::make_shared<
-      std::promise<std::unique_ptr<std::vector<std::atomic<size_t>>>>>();
-  atomic_var_ref_ = p->get_future();
-  queue_group_->AddTask(2, [&vec_meta_info, p] {
-    auto* var_ref = new std::vector<std::atomic<size_t>>(vec_meta_info.size());
-    for (size_t i = 0; i < vec_meta_info.size(); ++i) {
-      (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
-    }
-    VLOG(4) << "AtomicVarRef:" << var_ref << " " << (*var_ref).size();
-    p->set_value(std::unique_ptr<std::vector<std::atomic<size_t>>>(var_ref));
-  });
+  atomic_var_ref_ =
+      queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&vec_meta_info] {
+        auto var_ref = std::make_unique<std::vector<std::atomic<size_t>>>(
+            vec_meta_info.size());
+        for (size_t i = 0; i < vec_meta_info.size(); ++i) {
+          (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
+        }
+        VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size();
+        return var_ref;
+      });
}
bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
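Note: `AddAwaitableTask` hands the caller a `std::future` for the task's return value, which is why the hand-rolled `std::promise`/`set_value` code above could be deleted. The snippet below is a minimal, self-contained sketch of that idea built on `std::packaged_task`; the names `AddAwaitableTaskSketch` and the single-thread `post` stand-in are illustrative only and are not part of the actual WorkQueueGroup API.

```cpp
// Sketch only: wrap a callable in std::packaged_task so that whatever
// posts it to a queue can return a std::future for its result.
#include <atomic>
#include <cstddef>
#include <functional>
#include <future>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

template <typename Post, typename F>
auto AddAwaitableTaskSketch(Post&& post, F&& fn)
    -> std::future<decltype(fn())> {
  using R = decltype(fn());
  // The packaged_task owns the callable and feeds its result (or any
  // exception) into the shared state behind the returned future.
  auto task = std::make_shared<std::packaged_task<R()>>(std::forward<F>(fn));
  std::future<R> result = task->get_future();
  post([task] { (*task)(); });  // hand the wrapped callable to the "queue"
  return result;
}

int main() {
  // Stand-in for the work queue: run the single posted job on one thread.
  std::thread worker;
  auto post = [&worker](std::function<void()> job) {
    worker = std::thread(std::move(job));
  };

  std::vector<size_t> dependency_count{2, 0, 1};
  auto deps_future = AddAwaitableTaskSketch(post, [&dependency_count] {
    auto deps = std::make_unique<std::vector<std::atomic<size_t>>>(
        dependency_count.size());
    for (size_t i = 0; i < dependency_count.size(); ++i) {
      (*deps)[i] = dependency_count[i];
    }
    return deps;  // ownership travels back through the future
  });

  auto deps = deps_future.get();  // blocks until the task has run
  worker.join();
  return (*deps)[0] == 2 ? 0 : 1;
}
```

Because the lambda returns the `std::unique_ptr` directly, ownership is transferred through the future, so no raw `new` plus `set_value` pairing is needed at the call site.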
@@ -17,6 +17,7 @@
#include <functional>
#include <memory>
#include <mutex>
+#include <shared_mutex>
#include <thread>
#include <type_traits>
#include <unordered_map>
@@ -60,16 +61,17 @@ class ThreadDataRegistry {
 private:
  // types
+  using LockType = std::shared_timed_mutex;
  class ThreadDataHolder;
  class ThreadDataRegistryImpl {
   public:
    void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) {
-      std::lock_guard<std::mutex> lock(lock_);
+      std::lock_guard<LockType> lock(lock_);
      tid_map_[tid] = tls_obj;
    }

    void UnregisterData(uint64_t tid) {
-      std::lock_guard<std::mutex> lock(lock_);
+      std::lock_guard<LockType> lock(lock_);
      tid_map_.erase(tid);
    }
@@ -77,7 +79,7 @@ class ThreadDataRegistry {
                                  std::is_copy_constructible<Alias>::value>>
    std::unordered_map<uint64_t, T> GetAllThreadDataByValue() {
      std::unordered_map<uint64_t, T> data_copy;
-      std::lock_guard<std::mutex> lock(lock_);
+      std::shared_lock<LockType> lock(lock_);
      data_copy.reserve(tid_map_.size());
      for (auto& kv : tid_map_) {
        data_copy.emplace(kv.first, kv.second->GetData());
@@ -88,7 +90,7 @@ class ThreadDataRegistry {
    std::unordered_map<uint64_t, std::reference_wrapper<T>>
    GetAllThreadDataByRef() {
      std::unordered_map<uint64_t, std::reference_wrapper<T>> data_ref;
-      std::lock_guard<std::mutex> lock(lock_);
+      std::shared_lock<LockType> lock(lock_);
      data_ref.reserve(tid_map_.size());
      for (auto& kv : tid_map_) {
        data_ref.emplace(kv.first, std::ref(kv.second->GetData()));
@@ -97,7 +99,7 @@ class ThreadDataRegistry {
    }

   private:
-    std::mutex lock_;
+    LockType lock_;
    std::unordered_map<uint64_t, ThreadDataHolder*> tid_map_;  // not owned
  };
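Note: with the switch from `std::mutex` to `std::shared_timed_mutex` (`LockType`), registration and unregistration still take the lock exclusively, while the snapshot getters only take it in shared mode, so concurrent readers no longer block each other. The class below is a simplified, hypothetical stand-in (not the real `ThreadDataRegistry`) that shows just this locking pattern.

```cpp
// Sketch only: reader/writer locking with std::shared_timed_mutex.
#include <cstdint>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

class TidMapSketch {
 public:
  using LockType = std::shared_timed_mutex;

  void RegisterData(uint64_t tid, const std::string& data) {
    std::lock_guard<LockType> lock(lock_);  // exclusive: mutates the map
    tid_map_[tid] = data;
  }

  void UnregisterData(uint64_t tid) {
    std::lock_guard<LockType> lock(lock_);  // exclusive: mutates the map
    tid_map_.erase(tid);
  }

  std::unordered_map<uint64_t, std::string> GetAllThreadDataByValue() const {
    std::shared_lock<LockType> lock(lock_);  // shared: many readers at once
    return tid_map_;
  }

 private:
  mutable LockType lock_;
  std::unordered_map<uint64_t, std::string> tid_map_;
};

int main() {
  TidMapSketch registry;
  registry.RegisterData(1, "thread-1 stats");
  registry.RegisterData(2, "thread-2 stats");
  auto snapshot = registry.GetAllThreadDataByValue();
  registry.UnregisterData(1);
  return snapshot.size() == 2 ? 0 : 1;
}
```

Under `std::shared_lock`, any number of snapshot calls can proceed in parallel; a writer waits until all readers have released the shared lock before taking it exclusively.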