From 5fa535b71785cc2abc58f3e0f76a2e7c73dfd497 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 19 Mar 2018 19:09:45 +0800 Subject: [PATCH] Wait all thread done --- paddle/fluid/framework/parallel_executor.cc | 16 ++++++++++++---- paddle/fluid/framework/parallel_executor.h | 7 ++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ac2c87845..938f4317b 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -699,8 +699,11 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, pending_ops.insert({op, op->inputs_.size()}); } + std::vector> op_threads; + op_threads.reserve(pending_ops.size() + to_run.size()); + for (auto *op : to_run) { - RunOp(pending_vars, op); + op_threads.emplace_back(RunOp(pending_vars, op)); } while (!pending_ops.empty()) { @@ -731,15 +734,20 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, } for (auto *op : to_run) { pending_ops.erase(op); - RunOp(pending_vars, op); + op_threads.emplace_back(RunOp(pending_vars, op)); } } + + for (auto &t : op_threads) { + t.get(); // Join all workers + } + fetch_ops.clear(); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetched_data->tensors_; } -void ParallelExecutor::RunOp( +std::future ParallelExecutor::RunOp( std::unordered_map &pending_vars, OpHandle *op) const { std::vector *ready_buffer = new std::vector(); @@ -760,7 +768,7 @@ void ParallelExecutor::RunOp( LOG(FATAL) << "Unknown exception catched"; } }; - member_->pool_.enqueue(op_run); + return member_->pool_.enqueue(op_run); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index b6fa6fb2d..badf7c5ea 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once +#include #include - #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" @@ -81,8 +81,9 @@ class ParallelExecutor { void BuildNCCLCommunicator() const; - void RunOp(std::unordered_map& pending_vars, - OpHandle* op) const; + std::future RunOp( + std::unordered_map& pending_vars, + OpHandle* op) const; void PolishGraphToSupportDataHarzaeds() const; }; -- GitLab