From 4c3361cda826f9ca2e5c96637b1481211f2bba63 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Sat, 24 Mar 2018 13:39:57 +0800
Subject: [PATCH] Extract GraphExecutor

---
 paddle/fluid/framework/parallel_executor.cc | 323 ++++++++++++--------
 1 file changed, 194 insertions(+), 129 deletions(-)

diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 4ebb89181cd..78ef66be514 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -24,42 +24,184 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-class ParallelExecutorPrivate {
+using details::DummyVarHandle;
+using details::FetchOpHandle;
+using details::OpHandleBase;
+using details::SSAGraph;
+using details::VarHandleBase;
+
+class SSAGraphExecutor {
+  DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
+
  public:
-  explicit ParallelExecutorPrivate(size_t num_threads,
-                                   const std::vector<platform::Place> &places)
-      : places_(places),
-        fetch_dev_ctxs_(places),
-        pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}
+  explicit SSAGraphExecutor(SSAGraph *graph) : graph_(*graph) {}
 
-  std::vector<platform::Place> places_;
-  platform::DeviceContextPool fetch_dev_ctxs_;
-  std::vector<Scope *> local_scopes_;
-  Scope *global_scope_;
+  virtual ~SSAGraphExecutor() {}
 
-  std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
+  virtual void Run(Scope *global_scope,
+                   const std::vector<std::string> &fetch_tensors,
+                   const std::string &fetch_list_name) = 0;
 
-  details::SSAGraph graph_;
+ protected:
+  SSAGraph &graph_;
+};
 
-  // Use a simpler thread pool, might be faster.
-  std::unique_ptr<ThreadPool> pool_;
+class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
+ public:
+  ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
+                           const std::vector<Scope *> &local_scopes,
+                           const std::vector<platform::Place> &places,
+                           SSAGraph *graph)
+      : SSAGraphExecutor(graph),
+        pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
+        local_scopes_(local_scopes),
+        places_(places),
+        fetch_ctxs_(places),
+        use_event_(use_event) {}
+
+  void Run(Scope *global_scope, const std::vector<std::string> &fetch_tensors,
+           const std::string &fetch_list_name) override {
+    std::unordered_map<OpHandleBase *, size_t> pending_ops;
+    std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
+    std::unordered_set<OpHandleBase *> ready_ops;
+
+    auto InsertPendingVar = [&pending_vars](VarHandleBase &var) {
+      pending_vars[&var] = var.generated_op_ == nullptr;
+    };
 
-  std::unique_ptr<platform::EnforceNotMet> exception_;
+    auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
+      pending_ops.insert({&op_instance, op_instance.inputs_.size()});
+    };
+
+    // Transform SSAGraph to pending_ops & pending_vars
+    for (auto &var_map : graph_.vars_) {
+      for (auto &name_pair : var_map) {
+        for (auto &version_pair : name_pair.second) {
+          InsertPendingVar(version_pair.second);
+        }
+      }
+    }
+    for (auto &var : graph_.dep_vars_) {
+      InsertPendingVar(*var);
+    }
+
+    for (auto &op : graph_.ops_) {
+      if (op->inputs_.empty()) {  // Special case, Op has no input.
+        ready_ops.insert(op.get());
+      } else {
+        InsertPendingOp(*op);
+      }
+    }
+
+    // Step 2. Insert FetchOps
+    std::vector<FetchOpHandle> fetch_ops;
+    std::vector<DummyVarHandle> dummy_vars;
+    FeedFetchList fetch_data(fetch_tensors.size());
+
+    std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
+
+    for (auto &fetch_var_name : fetch_tensors) {
+      for (auto &var_map : graph_.vars_) {
+        auto it = var_map.find(fetch_var_name);
+        if (it != var_map.end()) {
+          fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
+        }
+      }
+    }
 
-  void RunOp(bool use_event,
-             std::unordered_map<details::VarHandleBase *, std::atomic<bool>>
-                 &pending_vars,
-             details::OpHandleBase *op) {
+    for (size_t i = 0; i < fetch_tensors.size(); ++i) {
+      auto &var_name = fetch_tensors[i];
+      auto &vars = fetched_vars[var_name];
+      fetch_ops.emplace_back(&fetch_data, i, &local_scopes_);
+      details::FetchOpHandle *op = &fetch_ops.back();
+
+      // FIXME: Use new device context
+      for (auto &p : places_) {
+        op->dev_ctx_[p] = fetch_ctxs_.Get(p);
+      }
+
+      for (auto *var : vars) {
+        op->AddInput(var);
+      }
+
+      dummy_vars.emplace_back();
+      auto *var = &dummy_vars.back();
+      var->generated_op_ = nullptr;
+      op->AddOutput(var);
+      InsertPendingVar(*var);
+      InsertPendingOp(*op);
+    }
+
+    auto run_all_ready_ops = [&] {
+      for (auto *op : ready_ops) {
+        RunOp(pending_vars, op);
+      }
+      ready_ops.clear();
+    };
+
+    // Step 3. Execution
+    while (!pending_vars.empty()) {
+      // 1. Run All Ready ops
+      run_all_ready_ops();
+
+      // 2. Find ready variable
+      VarHandleBase *ready_var = nullptr;
+      for (auto &pair : pending_vars) {
+        if (pair.second.load(std::memory_order_acquire)) {
+          ready_var = pair.first;
+          break;
+        }
+      }
+
+      // if there is no variable ready
+      if (ready_var == nullptr) {
+        // FIXME use conditional var instead of busy wait.
+        // if there is an exception, throw it
+        if (exception_) {
+          throw * exception_;
+        }
+        // keep waiting the ready variables
+        continue;
+      }
+
+      // 3. Remove the dependency of ready_var.
+      // Find the ready_ops after the ready_var.
+      pending_vars.erase(ready_var);
+      for (auto *op : ready_var->pending_ops_) {
+        auto &deps = pending_ops[op];
+        --deps;
+        if (deps == 0) {
+          ready_ops.insert(op);
+        }
+      }
+      // Keep loop until all vars are ready.
+    }
+
+    // Wait FetchOps.
+    for (auto &fetch_op : fetch_ops) {
+      fetch_op.WaitAndMergeCPUTensors();
+    }
+
+    *global_scope->Var(fetch_list_name)->GetMutable<FeedFetchList>() =
+        fetch_data;
+  }
+
+  ~ThreadedSSAGraphExecutor() {}
+
+ private:
+  void RunOp(
+      std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
+      details::OpHandleBase *op) {
     std::vector<std::atomic<bool> *> *ready_buffer =
         new std::vector<std::atomic<bool> *>();
     for (auto *var : op->outputs_) {
       ready_buffer->emplace_back(&pending_vars[var]);
     }
-    auto op_run = [ready_buffer, op, this, use_event] {
+    auto op_run = [ready_buffer, op, this] {
       try {
         VLOG(10) << op->DebugString();
-        op->Run(use_event);
+        op->Run(use_event_);
         for (auto *ready : *ready_buffer) {
           ready->store(true, std::memory_order_release);
         }
@@ -76,6 +218,31 @@ class ParallelExecutorPrivate {
       op_run();
     }
   }
+
+ private:
+  std::unique_ptr<::ThreadPool> pool_;
+  std::vector<Scope *> local_scopes_;
+  std::vector<platform::Place> places_;
+  platform::DeviceContextPool fetch_ctxs_;
+  const bool use_event_;
+  std::unique_ptr<platform::EnforceNotMet> exception_;
+};
+
+class ParallelExecutorPrivate {
+ public:
+  explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
+      : places_(places), fetch_dev_ctxs_(places) {}
+
+  std::vector<platform::Place> places_;
+  platform::DeviceContextPool fetch_dev_ctxs_;
+  std::vector<Scope *> local_scopes_;
+  Scope *global_scope_;
+
+  std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
+
+  details::SSAGraph graph_;
+
+  std::unique_ptr<SSAGraphExecutor> executor_;
 };
 
 ParallelExecutor::ParallelExecutor(
@@ -83,7 +250,7 @@ ParallelExecutor::ParallelExecutor(
     const std::unordered_set<std::string> &params,
     const ProgramDesc &startup_program, const ProgramDesc &main_program,
     const std::string &loss_var_name, Scope *scope)
-    : member_(new ParallelExecutorPrivate(num_threads, places)) {
+    : member_(new ParallelExecutorPrivate(places)) {
   member_->global_scope_ = scope;
 
   // Step 1. RunStartupProgram and Bcast the params to devs.
@@ -109,6 +276,9 @@ ParallelExecutor::ParallelExecutor(
       member_->nccl_ctxs_.get());
   builder.Build(main_program, &member_->graph_);
 
+  member_->executor_.reset(new ThreadedSSAGraphExecutor(
+      num_threads, true, member_->local_scopes_, places, &member_->graph_));
+
   // Step 3. Create vars in each scope;
   for (auto *scope : member_->local_scopes_) {
     for (auto *var : main_program.Block(0).AllVars()) {
@@ -168,113 +338,8 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
 
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
-  bool use_event = true;
-  FeedFetchList fetched_data(fetch_tensors.size());
-  // Version --> VarHandle
-  member_->exception_.reset();
-  std::unordered_map<details::VarHandleBase *, std::atomic<bool>> pending_vars;
-  std::unordered_map<details::OpHandleBase *, size_t> pending_ops;
-  std::vector<details::DummyVarHandle> dummy_vars;
-
-  for (auto &var_map : member_->graph_.vars_) {
-    for (auto &name_pair : var_map) {
-      for (auto &version_pair : name_pair.second) {
-        pending_vars[&version_pair.second] =
-            version_pair.second.generated_op_ == nullptr;
-      }
-    }
-  }
-
-  for (auto &var : member_->graph_.dep_vars_) {
-    pending_vars[var.get()] = var->generated_op_ == nullptr;
-  }
-
-  std::vector<details::OpHandleBase *> to_run;
-
-  for (auto &op : member_->graph_.ops_) {
-    if (op->inputs_.empty()) {  // Special case, Op has no input.
-      to_run.emplace_back(op.get());
-    } else {
-      pending_ops.insert({op.get(), op->inputs_.size()});
-    }
-  }
-
-  std::unordered_map<std::string, std::vector<details::VarHandleBase *>>
-      fetched_vars;
-
-  for (auto &fetch_var_name : fetch_tensors) {
-    for (auto &var_map : member_->graph_.vars_) {
-      auto it = var_map.find(fetch_var_name);
-      if (it != var_map.end()) {
-        fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
-      }
-    }
-  }
-
-  std::vector<details::FetchOpHandle> fetch_ops;
-
-  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
-    auto &var_name = fetch_tensors[i];
-    auto &vars = fetched_vars[var_name];
-    fetch_ops.emplace_back(&fetched_data, i, &member_->local_scopes_);
-    details::FetchOpHandle *op = &fetch_ops.back();
-
-    // FIXME: Use new device context
-    for (auto &p : member_->places_) {
-      op->dev_ctx_[p] = member_->fetch_dev_ctxs_.Get(p);
-    }
-
-    for (auto *var : vars) {
-      op->AddInput(var);
-    }
-
-    dummy_vars.emplace_back();
-    auto *var = &dummy_vars.back();
-    op->AddOutput(var);
-    pending_vars[var] = false;
-
-    pending_ops.insert({op, op->inputs_.size()});
-  }
-
-  for (auto *op : to_run) {
-    member_->RunOp(use_event, pending_vars, op);
-  }
-
-  while (!pending_vars.empty()) {
-    details::VarHandleBase *ready_var = nullptr;
-    for (auto &pair : pending_vars) {
-      if (pair.second.load(std::memory_order_acquire)) {
-        ready_var = pair.first;
-      }
-    }
-    if (ready_var == nullptr) {
-      // FIXME use conditional var instead of busy wait.
-      if (member_->exception_) {
-        throw * member_->exception_;
-      }
-      continue;
-    }
-    pending_vars.erase(ready_var);
-    to_run.clear();
-    for (auto *op : ready_var->pending_ops_) {
-      auto &deps = pending_ops[op];
-      --deps;
-      if (deps == 0) {
-        to_run.emplace_back(op);
-      }
-    }
-    for (auto *op : to_run) {
-      pending_ops.erase(op);
-      member_->RunOp(use_event, pending_vars, op);
-    }
-  }
-
-  for (auto &fetch_op : fetch_ops) {
-    fetch_op.WaitAndMergeCPUTensors();
-  }
-
-  *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
-      fetched_data;
+  member_->executor_->Run(member_->global_scope_, fetch_tensors,
+                          fetched_var_name);
 }
 
 }  // namespace framework
-- 
GitLab
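
Reviewer note, not part of the patch: the scheduling logic that this change moves into ThreadedSSAGraphExecutor::Run is a dependency-count scheme over the SSA graph. Every op starts with a counter equal to its number of input variables, an op becomes ready when that counter reaches zero, and a finished op flips atomic ready flags on its output variables, which the main loop polls (a busy wait, as the FIXME in the patch notes). The following minimal standalone C++ sketch shows only that scheme; the Var, Op, and RunGraph names are illustrative stand-ins rather than Paddle types, and it omits the exception forwarding, thread pool, and fetch-op handling of the real executor.

// Standalone sketch of the dependency-counting schedule used by
// ThreadedSSAGraphExecutor::Run (illustrative, not Paddle code).
#include <atomic>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <future>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Op;

struct Var {
  std::atomic<bool> ready{false};
  std::vector<Op *> pending_ops;  // ops that consume this variable
};

struct Op {
  std::vector<Var *> inputs;
  std::vector<Var *> outputs;
  std::function<void()> body;
};

// Runs every op once all of its inputs are ready, mirroring the polling loop
// in the patch: workers flip atomic flags, the main thread resolves them.
void RunGraph(const std::vector<Op *> &ops) {
  std::unordered_map<Op *, std::size_t> pending_ops;  // op -> unresolved inputs
  std::unordered_set<Var *> pending_vars;             // vars not yet produced
  std::vector<Op *> ready_ops;
  std::vector<std::future<void>> futures;

  for (Op *op : ops) {
    for (Var *v : op->inputs) pending_vars.insert(v);
    for (Var *v : op->outputs) pending_vars.insert(v);
    if (op->inputs.empty()) {
      ready_ops.push_back(op);  // source op, can run immediately
    } else {
      pending_ops[op] = op->inputs.size();
    }
  }

  while (!pending_vars.empty()) {
    // 1. Launch every ready op; when it finishes, its outputs become ready.
    for (Op *op : ready_ops) {
      futures.push_back(std::async(std::launch::async, [op] {
        op->body();
        for (Var *out : op->outputs) {
          out->ready.store(true, std::memory_order_release);
        }
      }));
    }
    ready_ops.clear();

    // 2. Busy-wait for some pending variable to become ready.
    Var *ready_var = nullptr;
    for (Var *v : pending_vars) {
      if (v->ready.load(std::memory_order_acquire)) {
        ready_var = v;
        break;
      }
    }
    if (ready_var == nullptr) continue;

    // 3. Resolve it and unlock ops whose inputs are now all ready.
    pending_vars.erase(ready_var);
    for (Op *op : ready_var->pending_ops) {
      if (--pending_ops[op] == 0) ready_ops.push_back(op);
    }
  }
  for (auto &f : futures) f.get();
}

int main() {
  Var a, b;
  Op produce{{}, {&a}, [] { std::puts("produce -> a"); }};
  Op consume{{&a}, {&b}, [] { std::puts("consume: a -> b"); }};
  a.pending_ops.push_back(&consume);
  RunGraph({&produce, &consume});
}

Replacing the polling step with a condition variable signalled from the worker callback would remove the busy wait that the FIXME comment in the patch refers to.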