diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ac2c87845341bf48104c9f0c62edc394b85218bf..938f4317b1d41e66d9e4dc72e4d704f2add10762 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -699,8 +699,11 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, pending_ops.insert({op, op->inputs_.size()}); } + std::vector> op_threads; + op_threads.reserve(pending_ops.size() + to_run.size()); + for (auto *op : to_run) { - RunOp(pending_vars, op); + op_threads.emplace_back(RunOp(pending_vars, op)); } while (!pending_ops.empty()) { @@ -731,15 +734,20 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, } for (auto *op : to_run) { pending_ops.erase(op); - RunOp(pending_vars, op); + op_threads.emplace_back(RunOp(pending_vars, op)); } } + + for (auto &t : op_threads) { + t.get(); // Join all workers + } + fetch_ops.clear(); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetched_data->tensors_; } -void ParallelExecutor::RunOp( +std::future ParallelExecutor::RunOp( std::unordered_map &pending_vars, OpHandle *op) const { std::vector *ready_buffer = new std::vector(); @@ -760,7 +768,7 @@ void ParallelExecutor::RunOp( LOG(FATAL) << "Unknown exception catched"; } }; - member_->pool_.enqueue(op_run); + return member_->pool_.enqueue(op_run); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index b6fa6fb2d87f48cda432bbb3939b615a2a0593ea..badf7c5ea746b0677b624ec84389d2e353b7e736 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once +#include #include - #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" @@ -81,8 +81,9 @@ class ParallelExecutor { void BuildNCCLCommunicator() const; - void RunOp(std::unordered_map& pending_vars, - OpHandle* op) const; + std::future RunOp( + std::unordered_map& pending_vars, + OpHandle* op) const; void PolishGraphToSupportDataHarzaeds() const; };