Commit 5fa535b7 authored by Yu Yang

Wait for all threads to finish

Parent 7bff02b2
@@ -699,8 +699,11 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     pending_ops.insert({op, op->inputs_.size()});
   }
 
+  std::vector<std::future<void>> op_threads;
+  op_threads.reserve(pending_ops.size() + to_run.size());
+
   for (auto *op : to_run) {
-    RunOp(pending_vars, op);
+    op_threads.emplace_back(RunOp(pending_vars, op));
   }
 
   while (!pending_ops.empty()) {
@@ -731,15 +734,20 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     }
     for (auto *op : to_run) {
       pending_ops.erase(op);
-      RunOp(pending_vars, op);
+      op_threads.emplace_back(RunOp(pending_vars, op));
     }
   }
+
+  for (auto &t : op_threads) {
+    t.get();  // Join all workers
+  }
+
   fetch_ops.clear();
   *member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
       fetched_data->tensors_;
 }
 
-void ParallelExecutor::RunOp(
+std::future<void> ParallelExecutor::RunOp(
     std::unordered_map<VarHandleBase *, GuardedBool> &pending_vars,
     OpHandle *op) const {
   std::vector<GuardedBool *> *ready_buffer = new std::vector<GuardedBool *>();
@@ -760,7 +768,7 @@ void ParallelExecutor::RunOp(
       LOG(FATAL) << "Unknown exception catched";
     }
   };
 
-  member_->pool_.enqueue(op_run);
+  return member_->pool_.enqueue(op_run);
 }
 
 }  // namespace framework
 }  // namespace paddle
@@ -14,8 +14,8 @@ limitations under the License. */
 #pragma once
 
+#include <future>
 #include <unordered_set>
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/program_desc.h"
@@ -81,8 +81,9 @@ class ParallelExecutor {
   void BuildNCCLCommunicator() const;
 
-  void RunOp(std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars,
-             OpHandle* op) const;
+  std::future<void> RunOp(
+      std::unordered_map<VarHandleBase*, GuardedBool>& pending_vars,
+      OpHandle* op) const;
 
   void PolishGraphToSupportDataHarzaeds() const;
 };
...
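For readers skimming the diff, the pattern is: RunOp now hands back the std::future<void> produced by the thread pool's enqueue call, Run collects one future per launched op, and a final get() loop blocks until every worker has finished (rethrowing any exception raised inside an op). The standalone sketch below illustrates the same collect-and-join pattern; it uses std::async in place of member_->pool_.enqueue, and RunOpSketch is a hypothetical stand-in, not the actual ParallelExecutor API.

// Minimal sketch of the collect-and-join pattern introduced by this commit.
// Assumption: std::async stands in for member_->pool_.enqueue, and RunOpSketch
// is a hypothetical stand-in for ParallelExecutor::RunOp.
#include <future>
#include <iostream>
#include <vector>

// Launch one unit of work asynchronously and return its future to the caller
// instead of detaching it, mirroring the new RunOp signature.
std::future<void> RunOpSketch(int op_id) {
  return std::async(std::launch::async, [op_id] {
    // Simulated op body; the real code runs the OpHandle and marks its
    // output variables as ready.
    std::cout << "op " << op_id << " done\n";
  });
}

int main() {
  std::vector<std::future<void>> op_threads;
  op_threads.reserve(4);
  for (int i = 0; i < 4; ++i) {
    op_threads.emplace_back(RunOpSketch(i));
  }
  // Join all workers, as the new loop in ParallelExecutor::Run does; get()
  // also rethrows any exception that escaped the worker body.
  for (auto &t : op_threads) {
    t.get();
  }
  return 0;
}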