From 193c0a7e4333ca7e403089ef1f9e66c79d56c68a Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 15 Mar 2018 17:27:42 +0800 Subject: [PATCH] Handle var hazard --- paddle/fluid/framework/parallel_executor.cc | 137 +++++++++++++++++--- 1 file changed, 121 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 7af5cc075c..e98fedb68d 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -28,42 +28,79 @@ namespace framework { struct OpHandle; -struct VarHandle { +struct VarHandleBase { + virtual ~VarHandleBase() {} + virtual std::string DebugString() const = 0; + + OpHandle *generated_op_; + std::vector pending_ops_; +}; + +struct VarHandle : public VarHandleBase { + std::string DebugString() const override { + std::stringstream ss; + ss << name_ << ":" << place_; + return ss.str(); + } + size_t version_; std::string name_; platform::Place place_; +}; - OpHandle *generated_op_; - - std::vector pending_ops_; +struct DependencyVarHandle : public VarHandleBase { + std::string DebugString() const override { return "Deps var"; } }; struct OpHandle { - std::vector inputs_; - std::vector outputs_; + std::vector inputs_; + std::vector outputs_; + std::unordered_map + dev_ctx_; std::string DebugString() { std::stringstream ss; ss << "("; for (auto *var : inputs_) { - ss << var->name_ << ":" << var->place_ << ", "; + ss << var->DebugString() << ", "; } ss << ") --> ("; for (auto *var : outputs_) { - ss << var->name_ << ":" << var->place_ << ", "; + ss << var->DebugString() << ", "; } ss << ")\n"; return ss.str(); } virtual ~OpHandle() {} + + virtual void Run() {} + virtual void Wait() {} }; struct ComputationOpHandle : public OpHandle { std::unique_ptr op_; + Scope *scope_; + platform::Place place_; - explicit ComputationOpHandle(const OpDesc &op_desc) - : op_(framework::OpRegistry::CreateOp(op_desc)) {} + explicit ComputationOpHandle(const 
OpDesc &op_desc, Scope *scope, + platform::Place place) + : op_(framework::OpRegistry::CreateOp(op_desc)), + scope_(scope), + place_(place) {} + + void Run() override { + // Wait other op if necessary + auto *cur_ctx = dev_ctx_[place_]; + for (auto *in : inputs_) { + if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) { + in->generated_op_->Wait(); + } + } + + op_->Run(*scope_, place_); + } }; struct ScaleLossGradOpHandle : public OpHandle {}; @@ -122,12 +159,27 @@ class ParallelExecutorPrivate { #endif + platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) { + if (platform::is_cpu_place(place) || local_scopes_.size() == 1) { + return const_cast( + platform::DeviceContextPool::Instance().Get(place)); + } else { +#ifdef PADDLE_WITH_CUDA + return GetNCCLCtx(place).ctx_.get(); +#else + PADDLE_THROW("Not compiled with CUDA") +#endif + } + } + platform::Place main_place_; std::unordered_map>, platform::PlaceHash> vars_; + std::unordered_set> dep_vars_; + std::vector> ops_; ThreadPool pool_; @@ -170,7 +222,7 @@ ParallelExecutor::ParallelExecutor( void ParallelExecutor::ConstructDependencyGraph( const std::unordered_set ¶ms, const ProgramDesc &main_program, const std::string &loss_var_name) const { - std::unordered_set grads; + std::unordered_set grads; for (auto &each_param : params) { grads.insert(each_param + "@GRAD"); } @@ -188,8 +240,11 @@ void ParallelExecutor::ConstructDependencyGraph( } for (auto &pair : member_->local_scopes_) { - member_->ops_.emplace_back(new ComputationOpHandle(*op)); + member_->ops_.emplace_back( + new ComputationOpHandle(*op, pair.second, pair.first)); auto *op_handle = member_->ops_.back().get(); + op_handle->dev_ctx_[pair.first] = const_cast( + platform::DeviceContextPool::Instance().Get(pair.first)); auto var_names = op->InputArgumentNames(); @@ -210,8 +265,11 @@ void ParallelExecutor::ConstructDependencyGraph( if (var_names.size() == 1 && var_names[0] == loss_var_name) { // Insert ScaleCost 
OpHandle member_->ops_.emplace_back(new ScaleLossGradOpHandle()); - op_handle = member_->ops_.back().get(); + + op_handle->dev_ctx_[pair.first] = + member_->CommunicationDevCtx(pair.first); + auto &place = pair.first; VarHandle *loss = GetVarHandle(loss_var_name, place); loss->pending_ops_.emplace_back(op_handle); @@ -251,11 +309,54 @@ void ParallelExecutor::ConstructDependencyGraph( var.name_ = og; var.version_ = vars.size() - 1; op_handle->outputs_.emplace_back(&var); + + for (auto &pair : member_->local_scopes_) { + op_handle->dev_ctx_[pair.first] = + member_->CommunicationDevCtx(pair.first); + } } } } } } + + /** + * Dependency graph has been constructed. However, there are still data + * hazards that need to be handled. + * + * We only handle write after read (WAR), since it should not have a write + * after write in a program. If there are write after write operators, we need + * to prune them. + * + * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) + */ + + for (auto &place_pair : member_->vars_) { + for (auto &name_pair : place_pair.second) { + if (name_pair.second.size() <= 1) { + return; + } + auto it_new = name_pair.second.rbegin(); + auto it_old = name_pair.second.rbegin(); + ++it_old; + for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { + auto *write_op = it_new->second.generated_op_; + auto &read_ops = it_old->second.pending_ops_; + + for (auto *read_op : read_ops) { + // Manually add a dependency var from read_op to write_op; + + auto *dep_var = new DependencyVarHandle(); + dep_var->generated_op_ = read_op; + read_op->outputs_.emplace_back(dep_var); + + dep_var->pending_ops_.emplace_back(write_op); + write_op->inputs_.emplace_back(dep_var); + member_->dep_vars_.emplace(dep_var); + } + } + } + } } void ParallelExecutor::GenerateVar(OpHandle *op_handle, @@ -349,7 +450,7 @@ std::vector ParallelExecutor::Run( const std::vector &fetch_tensors) { // Version --> VarHandle -
std::unordered_map pending_vars; std::unordered_map pending_ops; for (auto &place_pair : member_->vars_) { @@ -361,12 +462,16 @@ std::vector ParallelExecutor::Run( } } + for (auto &var : member_->dep_vars_) { + pending_vars[var.get()] = var->generated_op_ == nullptr; + } + for (auto &op : member_->ops_) { pending_ops.insert({op.get(), op->inputs_.size()}); } while (!pending_ops.empty()) { - VarHandle *ready_var = nullptr; + VarHandleBase *ready_var = nullptr; for (auto &pair : pending_vars) { if (pair.second) { ready_var = pair.first; @@ -400,7 +505,7 @@ std::vector ParallelExecutor::Run( auto op_run = [ready_buffer, op] { // TODO(yy) Check Previous Op has same dev ctx. - LOG(INFO) << "Run " << op->DebugString(); + op->Run(); for (auto *ready : ready_buffer) { *ready = true; } -- GitLab