From 79989c902530fcaf525161b8d1b3eaee9d634291 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 21 Mar 2018 20:17:11 +0800 Subject: [PATCH] Add SSA builder --- paddle/fluid/framework/parallel_executor.cc | 369 +++++++++++--------- paddle/fluid/framework/parallel_executor.h | 4 - 2 files changed, 199 insertions(+), 174 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 37bfdc0df52..b2be3d13055 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -43,79 +43,211 @@ struct SSAGraph { std::vector> ops_; }; -/** - * We only handle write after read(WAR), since it should not have a write - * after write in program. If there are write after write operators, we need - * prune them. - * - * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) - */ -static void PolishGraphToSupportDataHazards(SSAGraph *graph) { - for (auto &var_map : graph->vars_) { - for (auto &name_pair : var_map) { - if (name_pair.second.size() <= 1) { - return; - } - auto it_new = name_pair.second.rbegin(); - auto it_old = name_pair.second.rbegin(); - ++it_old; - for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { - auto *write_op = it_new->second.generated_op_; - auto &read_ops = it_old->second.pending_ops_; - auto *ex_write_op = it_old->second.generated_op_; - - if (ex_write_op == nullptr) { // Nobody write this var. - continue; +class SSAGraphBuilder { + public: + virtual ~SSAGraphBuilder() {} + virtual void Build(const ProgramDesc &program, SSAGraph *graph) const = 0; + + protected: + /** + * We only handle write after read(WAR), since it should not have a write + * after write in program. If there are write after write operators, we need + * prune them. + * + * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) + */ + static void PolishGraphToSupportDataHazards(SSAGraph *graph) { + for (auto &var_map : graph->vars_) { + for (auto &name_pair : var_map) { + if (name_pair.second.size() <= 1) { + return; } - - for (auto *read_op : read_ops) { - // Manually add a dependency var from read_op to write_op; - if (read_op == write_op) { - // Read Write is the same op. + auto it_new = name_pair.second.rbegin(); + auto it_old = name_pair.second.rbegin(); + ++it_old; + for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { + auto *write_op = it_new->second.generated_op_; + auto &read_ops = it_old->second.pending_ops_; + auto *ex_write_op = it_old->second.generated_op_; + + if (ex_write_op == nullptr) { // Nobody write this var. continue; } - auto *dep_var = new DummyVarHandle(); - read_op->AddOutput(dep_var); - write_op->AddInput(dep_var); - graph->dep_vars_.emplace(dep_var); + for (auto *read_op : read_ops) { + // Manually add a dependency var from read_op to write_op; + if (read_op == write_op) { + // Read Write is the same op. + continue; + } + + auto *dep_var = new DummyVarHandle(); + read_op->AddOutput(dep_var); + write_op->AddInput(dep_var); + graph->dep_vars_.emplace(dep_var); + } } } } } -} -static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph, - const std::string &each_var_name, - const platform::Place &place, - size_t place_offset) { - auto &var_holders = graph->vars_[place_offset]; - auto &var_holder = var_holders[each_var_name]; - VarHandle *var = nullptr; - if (var_holder.empty()) { - auto &init_var = var_holder[0]; - init_var.place_ = place; - init_var.name_ = each_var_name; - init_var.generated_op_ = nullptr; - init_var.version_ = 0; - var = &init_var; - } else { - var = &var_holder.rbegin()->second; + static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph, + const std::string &each_var_name, + const platform::Place &place, + size_t place_offset) { + auto &var_holders = graph->vars_[place_offset]; + auto &var_holder = var_holders[each_var_name]; + VarHandle *var = nullptr; + if (var_holder.empty()) { + auto &init_var = var_holder[0]; + init_var.place_ = place; + init_var.name_ = each_var_name; + init_var.generated_op_ = nullptr; + init_var.version_ = 0; + var = &init_var; + } else { + var = &var_holder.rbegin()->second; + } + return var; } - return var; -} -static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, - const std::string &each_var_name, - const platform::Place &place, size_t place_offset) { - auto &vars = graph->vars_[place_offset][each_var_name]; - size_t version = vars.size(); - auto &var = vars[version]; - var.version_ = version; - var.name_ = each_var_name; - var.place_ = place; - op_handle->AddOutput(&var); -} + static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, + const std::string &each_var_name, + const platform::Place &place, + size_t place_offset) { + auto &vars = graph->vars_[place_offset][each_var_name]; + size_t version = vars.size(); + auto &var = vars[version]; + var.version_ = version; + var.name_ = each_var_name; + var.place_ = place; + op_handle->AddOutput(&var); + } +}; + +class MultiDevSSAGraphBuilder : public SSAGraphBuilder { + public: + MultiDevSSAGraphBuilder(const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶ms, + const std::vector &local_scopes, + platform::NCCLContextMap *nccl_ctxs) + : loss_var_name_(loss_var_name), + places_(places), + local_scopes_(local_scopes), + nccl_ctxs_(nccl_ctxs) { + for (auto &p : params) { + grad_names_.insert(GradVarName(p)); + } + } + + void Build(const ProgramDesc &program, SSAGraph *graph) const override { + SSAGraph &result = *graph; + result.vars_.resize(places_.size()); + + bool is_forwarding = true; + for (auto *op : program.Block(0).AllOps()) { + bool change_forward = false; + if (!is_forwarding) { + // FIXME(yy): Do not hard code like this + if (op->OutputArgumentNames().size() == 1 && + op->OutputArgumentNames()[0] == GradVarName(loss_var_name_)) { + continue; // Drop fill 1. for backward coeff; + } + } + + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + auto *s = local_scopes_[i]; + + result.ops_.emplace_back(new ComputationOpHandle(*op, s, p)); + auto *op_handle = result.ops_.back().get(); + op_handle->dev_ctx_[p] = const_cast( + platform::DeviceContextPool::Instance().Get(p)); + + auto var_names = op->InputArgumentNames(); + + for (auto &each_var_name : var_names) { + VarHandle *var = + CreateOrGetLatestVarHandle(&result, each_var_name, p, i); + op_handle->AddInput(var); + } + var_names = op->OutputArgumentNames(); + + for (auto &each_var_name : var_names) { + CreateOpOutput(&result, op_handle, each_var_name, p, i); + } + + if (is_forwarding) { + if (var_names.size() == 1 && var_names[0] == loss_var_name_) { + // Insert ScaleCost OpHandle + op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p, + nccl_ctxs_->DevCtx(p)); + result.ops_.emplace_back(op_handle); + + // FIXME: Currently ScaleLossGradOp only use device_count as scale + // factor. So it does not depend on any other operators. + // VarHandle *loss = GetVarHandle(loss_var_name, place); + // loss->pending_ops_.emplace_back(op_handle); + // op_handle->inputs_.emplace_back(loss); + + CreateOpOutput(&result, op_handle, GradVarName(loss_var_name_), p, + i); + change_forward = true; + } + } + } + + if (change_forward) { + is_forwarding = false; + } + + if (!is_forwarding) { + auto var_names = op->OutputArgumentNames(); + for (auto &og : var_names) { + if (grad_names_.count(og) != 0) { // is param grad + // Insert NCCL AllReduce Op + result.ops_.emplace_back( + new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_)); + auto *op_handle = result.ops_.back().get(); + + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + auto &vars = result.vars_[i][og]; + + if (vars.empty()) { // This device has no data. continue. + continue; + } + auto *prev_grad = &vars[vars.size() - 1]; + op_handle->AddInput(prev_grad); + + auto &var = vars[vars.size()]; + var.place_ = p; + var.name_ = og; + var.version_ = vars.size() - 1; + + op_handle->AddOutput(&var); + } + } + } + } + } + + /* + Dependency graph has been constructed. However, there are still data + harzaeds need to be handled. + */ + PolishGraphToSupportDataHazards(&result); + } + + private: + std::string loss_var_name_; + const std::vector &places_; + const std::vector &local_scopes_; + platform::NCCLContextMap *nccl_ctxs_; + + std::unordered_set grad_names_; +}; class ParallelExecutorPrivate { public: @@ -123,9 +255,7 @@ class ParallelExecutorPrivate { const std::vector &places) : places_(places), fetch_dev_ctxs_(places), - pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) { - graph_.vars_.resize(places.size()); - } + pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {} std::vector places_; platform::DeviceContextPool fetch_dev_ctxs_; @@ -199,7 +329,10 @@ ParallelExecutor::ParallelExecutor( // Step 2. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp - ConstructDependencyGraph(params, main_program, loss_var_name); + MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, params, + member_->local_scopes_, + member_->nccl_ctxs_.get()); + builder.Build(main_program, &member_->graph_); // Step 3. Create vars in each scope; for (auto *scope : member_->local_scopes_) { @@ -213,110 +346,6 @@ ParallelExecutor::ParallelExecutor( } } -void ParallelExecutor::ConstructDependencyGraph( - const std::unordered_set ¶ms, - const ProgramDesc &main_program, const std::string &loss_var_name) const { - std::unordered_set grads; - for (auto &each_param : params) { - grads.insert(each_param + "@GRAD"); - } - - bool is_forwarding = true; - for (auto *op : main_program.Block(0).AllOps()) { - bool change_forward = false; - if (!is_forwarding) { - // FIXME(yy): Do not hard code like this - if (op->OutputArgumentNames().size() == 1 && - op->OutputArgumentNames()[0] == loss_var_name + "@GRAD") { - continue; // Drop fill 1. for backward coeff; - } - } - - for (size_t i = 0; i < member_->places_.size(); ++i) { - auto &p = member_->places_[i]; - auto *s = member_->local_scopes_[i]; - - member_->graph_.ops_.emplace_back(new ComputationOpHandle(*op, s, p)); - auto *op_handle = member_->graph_.ops_.back().get(); - op_handle->dev_ctx_[p] = const_cast( - platform::DeviceContextPool::Instance().Get(p)); - - auto var_names = op->InputArgumentNames(); - - for (auto &each_var_name : var_names) { - VarHandle *var = - CreateOrGetLatestVarHandle(&member_->graph_, each_var_name, p, i); - op_handle->AddInput(var); - } - var_names = op->OutputArgumentNames(); - - for (auto &each_var_name : var_names) { - CreateOpOutput(&member_->graph_, op_handle, each_var_name, p, i); - } - - if (is_forwarding) { - if (var_names.size() == 1 && var_names[0] == loss_var_name) { - // Insert ScaleCost OpHandle - op_handle = - new ScaleLossGradOpHandle(this->member_->local_scopes_.size(), s, - p, member_->nccl_ctxs_->DevCtx(p)); - member_->graph_.ops_.emplace_back(op_handle); - - // FIXME: Currently ScaleLossGradOp only use device_count as scale - // factor. So it does not depend on any other operators. - // VarHandle *loss = GetVarHandle(loss_var_name, place); - // loss->pending_ops_.emplace_back(op_handle); - // op_handle->inputs_.emplace_back(loss); - - CreateOpOutput(&member_->graph_, op_handle, loss_var_name + "@GRAD", - p, i); - change_forward = true; - } - } - } - - if (change_forward) { - is_forwarding = false; - } - - if (!is_forwarding) { - auto var_names = op->OutputArgumentNames(); - for (auto &og : var_names) { - if (grads.count(og) != 0) { // is param grad - // Insert NCCL AllReduce Op - member_->graph_.ops_.emplace_back(new NCCLAllReduceOpHandle( - member_->local_scopes_, member_->places_, *member_->nccl_ctxs_)); - auto *op_handle = member_->graph_.ops_.back().get(); - - for (size_t i = 0; i < member_->places_.size(); ++i) { - auto &p = member_->places_[i]; - auto &vars = member_->graph_.vars_[i][og]; - - if (vars.empty()) { // This device has no data. continue. - continue; - } - auto *prev_grad = &vars[vars.size() - 1]; - op_handle->AddInput(prev_grad); - - auto &var = vars[vars.size()]; - var.place_ = p; - var.name_ = og; - var.version_ = vars.size() - 1; - - op_handle->AddOutput(&var); - } - } - } - } - } - - /* - Dependency graph has been constructed. However, there are still data - harzaeds need to be handled. - */ - PolishGraphToSupportDataHazards(&member_->graph_); -} - void ParallelExecutor::BCastParamsToGPUs( const ProgramDesc &startup_program) const { #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 8c91c45d146..39a1c51b9e7 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -47,10 +47,6 @@ class ParallelExecutor { void BCastParamsToGPUs(const ProgramDesc& startup_program) const; - void ConstructDependencyGraph(const std::unordered_set& params, - const ProgramDesc& main_program, - const std::string& loss_var_name) const; - void BuildNCCLCommunicator() const; }; -- GitLab