未验证 提交 c21597cf 编写于 作者: Y Yu Yang 提交者: GitHub

fix(PE): use shared_ptr<BlockingQueue> for cross thread communication (#14136)

It seems that the blocking queue might be destroyed early than Run
method complete. It might because the Run method throw some unhandled
exception. However, it should be shared_ptr when multthread access an
resource. So change BlockingQueue as a shared_ptr.

test=develop
上级 5cc99c47
...@@ -92,13 +92,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -92,13 +92,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
size_t num_complete = 0; size_t num_complete = 0;
remaining_ = 0; remaining_ = 0;
BlockingQueue<size_t> complete_q; auto complete_q = std::make_shared<BlockingQueue<size_t>>();
for (auto op : bootstrap_ops_) { for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, &complete_q); RunOpAsync(op_deps.get(), op, complete_q);
} }
while (num_complete != op_deps->size()) { while (num_complete != op_deps->size()) {
size_t num_comp = complete_q.Pop(); size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) { if (num_comp == -1UL) {
int remaining = 0; int remaining = 0;
while (true) { while (true) {
...@@ -107,7 +107,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -107,7 +107,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
break; break;
} }
for (int i = 0; i < remaining; ++i) { for (int i = 0; i < remaining; ++i) {
complete_q.Pop(); complete_q->Pop();
} }
} }
exception_.ReThrow(); exception_.ReThrow();
...@@ -120,7 +120,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -120,7 +120,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
} }
void FastThreadedSSAGraphExecutor::RunOpAsync( void FastThreadedSSAGraphExecutor::RunOpAsync(
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps, std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op, BlockingQueue<size_t> *complete_q) { OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
++remaining_; ++remaining_;
this->pool_.enqueue([=] { this->pool_.enqueue([=] {
OpHandleBase *op_to_run = op; OpHandleBase *op_to_run = op;
...@@ -144,7 +145,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -144,7 +145,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
if (op_to_run == nullptr) { if (op_to_run == nullptr) {
op_to_run = pending_op; op_to_run = pending_op;
} else { } else {
this->RunOpAsync(op_deps, pending_op, complete_q); RunOpAsync(op_deps, pending_op, complete_q);
} }
} }
} }
...@@ -156,8 +157,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -156,8 +157,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
} }
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() { void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
atomic_op_deps_ = pool_.enqueue([&] { atomic_op_deps_ = pool_.enqueue([&] {
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps = auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
new std::unordered_map<OpHandleBase *, std::atomic<int>>;
for (auto &pair : op_deps_) { for (auto &pair : op_deps_) {
(*op_deps)[pair.first] = pair.second; (*op_deps)[pair.first] = pair.second;
} }
......
...@@ -50,7 +50,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -50,7 +50,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::atomic<int> remaining_; std::atomic<int> remaining_;
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps, void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op, BlockingQueue<size_t> *complete_q); OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
void PrepareAtomicOpDeps(); void PrepareAtomicOpDeps();
......
...@@ -39,7 +39,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -39,7 +39,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr)); new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr));
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars; std::unordered_set<VarHandleBase *> pending_vars;
BlockingQueue<VarHandleBase *> ready_vars; auto ready_vars = std::make_shared<BlockingQueue<VarHandleBase *>>();
std::unordered_set<OpHandleBase *> ready_ops; std::unordered_set<OpHandleBase *> ready_ops;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule // streams from multiple GPUs, it's faster to buffer them and schedule
...@@ -51,12 +51,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -51,12 +51,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); InsertPendingVar(&pending_vars, ready_vars.get(), version_pair.get());
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, &ready_vars, var.get()); InsertPendingVar(&pending_vars, ready_vars.get(), var.get());
} }
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) { for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
...@@ -73,12 +73,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -73,12 +73,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
&pending_vars, &ready_vars, &fetch_data); &pending_vars, ready_vars.get(), &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) { for (auto *op : set) {
running_ops_++; running_ops_++;
RunOp(&ready_vars, op); RunOp(ready_vars, op);
} }
set.clear(); set.clear();
}; };
...@@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
run_op_futures_.clear(); run_op_futures_.clear();
exception_holder_.Clear(); exception_holder_.Clear();
event.reset(nullptr); event.reset(nullptr);
// Step 3. Execution // Step 3. Execution
while (!pending_vars.empty()) { while (!pending_vars.empty()) {
// 1. Run All Ready ops // 1. Run All Ready ops
...@@ -103,7 +102,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -103,7 +102,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// 2. Find ready variable // 2. Find ready variable
bool timeout; bool timeout;
auto cur_ready_vars = ready_vars.PopAll(1, &timeout); auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) { if (timeout) {
if (exception_holder_.IsCaught()) { if (exception_holder_.IsCaught()) {
...@@ -133,7 +132,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -133,7 +132,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
} }
} }
PADDLE_ENFORCE(ready_ops.empty()); PADDLE_ENFORCE(ready_ops.empty());
// Wait FetchOps. // Wait FetchOps.
ClearFetchOp(graph_.get(), &fetch_ops); ClearFetchOp(graph_.get(), &fetch_ops);
...@@ -206,7 +204,8 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( ...@@ -206,7 +204,8 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
} }
void ThreadedSSAGraphExecutor::RunOp( void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) { const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
try { try {
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(10)) {
......
...@@ -51,7 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -51,7 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() {} ~ThreadedSSAGraphExecutor() {}
private: private:
void RunOp(BlockingQueue<VarHandleBase *> *ready_var_q, void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op); details::OpHandleBase *op);
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册