diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 938f4317b1d41e66d9e4dc72e4d704f2add10762..2fb274d3a56ac14e219bb7688ece3ead43d7c0ca 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -53,6 +53,10 @@ struct VarHandle : public VarHandleBase { platform::Place place_; }; +struct DummyVarHandle : public VarHandleBase { + std::string DebugString() const override { return "dummy"; } +}; + struct DependencyVarHandle : public VarHandleBase { std::string DebugString() const override { return "Dependency Variable"; } }; @@ -643,6 +647,7 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, member_->exception_.reset(); std::unordered_map pending_vars; std::unordered_map pending_ops; + std::vector dummy_vars; for (auto &place_pair : member_->vars_) { for (auto &name_pair : place_pair.second) { @@ -696,17 +701,21 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, var->pending_ops_.emplace(op); op->inputs_.emplace_back(var); } + + dummy_vars.emplace_back(); + auto *var = &dummy_vars.back(); + op->outputs_.emplace_back(var); + var->generated_op_ = op; + pending_vars[var] = false; + pending_ops.insert({op, op->inputs_.size()}); } - std::vector> op_threads; - op_threads.reserve(pending_ops.size() + to_run.size()); - for (auto *op : to_run) { - op_threads.emplace_back(RunOp(pending_vars, op)); + RunOp(pending_vars, op); } - while (!pending_ops.empty()) { + while (!pending_vars.empty()) { VarHandleBase *ready_var = nullptr; for (auto &pair : pending_vars) { if (pair.second) { @@ -715,12 +724,9 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, } if (ready_var == nullptr) { // FIXME use conditional var instead of busy wait. - if (member_->exception_) { throw * member_->exception_; } - - VLOG(3) << pending_vars.size(); continue; } pending_vars.erase(ready_var); @@ -734,20 +740,16 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, } for (auto *op : to_run) { pending_ops.erase(op); - op_threads.emplace_back(RunOp(pending_vars, op)); + RunOp(pending_vars, op); } } - for (auto &t : op_threads) { - t.get(); // Join all workers - } - fetch_ops.clear(); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetched_data->tensors_; } -std::future ParallelExecutor::RunOp( +void ParallelExecutor::RunOp( std::unordered_map &pending_vars, OpHandle *op) const { std::vector *ready_buffer = new std::vector(); @@ -768,7 +770,7 @@ std::future ParallelExecutor::RunOp( LOG(FATAL) << "Unknown exception catched"; } }; - return member_->pool_.enqueue(op_run); + member_->pool_.enqueue(op_run); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index badf7c5ea746b0677b624ec84389d2e353b7e736..8fe93fb62e1853d2b180a3cef1697354aec49a96 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -81,9 +81,8 @@ class ParallelExecutor { void BuildNCCLCommunicator() const; - std::future RunOp( - std::unordered_map& pending_vars, - OpHandle* op) const; + void RunOp(std::unordered_map& pending_vars, + OpHandle* op) const; void PolishGraphToSupportDataHarzaeds() const; };