From f52d78d18938b140673b30ce40dde95c4019a57f Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Tue, 12 Jun 2018 20:43:43 +0800
Subject: [PATCH] update by comment

---
 .../details/multi_devices_graph_builder.cc | 162 +++++++++---------
 .../details/multi_devices_graph_builder.h  |  14 +-
 2 files changed, 93 insertions(+), 83 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index cc07cfc552..cf5968993f 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
   }
+  balance_vars_.resize(places_.size(), 0);
 }
 
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
          checker(op.InputArgumentNames(), recv_vars);
 }
 
+size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
+    const std::vector<std::string> &var_names) const {
+  int64_t numel_sum = 0;
+  for (auto var_name : var_names) {
+    auto var_desc = all_vars_.at(var_name);
+    PADDLE_ENFORCE_NOT_NULL(var_desc);
+    auto dim = framework::make_ddim(var_desc->GetShape());
+    int64_t numel = framework::product(dim);
+    PADDLE_ENFORCE_GT(numel, 0);
+    numel_sum += numel;
+  }
+
+  auto smallest =
+      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
+  size_t dev_id =
+      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
+  balance_vars_[dev_id] += numel_sum;
+  return dev_id;
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
-  std::unordered_map<std::string, VarDesc *> all_vars;
   for (auto *var : program.Block(0).AllVars()) {
-    all_vars[var->Name()] = var;
+    all_vars_.emplace(var->Name(), var);
   }
 
   auto graph = new SSAGraph();
@@ -165,71 +185,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   bcast_var_name_set.resize(places_.size());
 
   size_t cur_device_id = 0;
-  std::vector<int64_t> balance_grads(places_.size(), 0);
-
-  auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
-    int64_t numel_all = 0;
-    for (auto var_name : var_names) {
-      auto var_desc = all_vars.at(var_name);
-      PADDLE_ENFORCE_NOT_NULL(var_desc);
-      auto dim = framework::make_ddim(var_desc->GetShape());
-      int64_t numel = framework::product(dim);
-      PADDLE_ENFORCE_GT(numel, 0);
-      numel_all += numel;
-    }
-
-    auto smallest =
-        std::min_element(std::begin(balance_grads), std::end(balance_grads));
-    size_t dev_id =
-        static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
-    balance_grads[dev_id] += numel_all;
-    return dev_id;
-  };
-
   bool is_forwarding = true;
 
   for (auto *op : program.Block(0).AllOps()) {
     if (boost::get<int>(
             op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      // append rpc op if program is distributed trainer main program.
-      // always use the first device
-      if (op->Type() == "send_vars") {
-        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
-        if (op_dev_id == -1) {
-          op_dev_id = get_appropriate_dev(op->InputArgumentNames());
-          for (auto &varname : op->InputArgumentNames()) {
-            var_name_on_devices_.emplace(varname, op_dev_id);
-          }
-        }
-        CreateRPCOp(&result, *op, op_dev_id);
-      } else if (op->Type() == "recv") {
-        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
-        for (auto &varname : op->OutputArgumentNames()) {
-          var_name_on_devices_.emplace(varname, op_dev_id);
-        }
-        CreateRPCOp(&result, *op, op_dev_id);
-      } else {
-        // send_barrier and fetch_barrier op would run on device 0
-        CreateRPCOp(&result, *op, 0);
-      }
+      CreateRPCOp(&result, *op);
     } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
-      if (op->Type() == "split_byref") {
-        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
-        for (auto &varname : op->OutputArgumentNames()) {
-          var_name_on_devices_.emplace(varname, op_dev_id);
-        }
-        CreateDistTrainOp(&result, *op, op_dev_id);
-      } else if (op->Type() == "concat") {
-        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
-        PADDLE_ENFORCE(op_dev_id != -1,
-                       "can not find right place to concatenate received var.");
-        CreateDistTrainOp(&result, *op, op_dev_id);
-      } else {
-        PADDLE_ENFORCE(
-            "the distribute training related op should be in [split_byref, "
-            "concat].");
-      }
+      CreateDistTrainOp(&result, *op);
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
@@ -267,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
           switch (strategy_.reduce_) {
             case BuildStrategy::ReduceStrategy::kReduce:
-              cur_device_id = get_appropriate_dev({g_name});
+              cur_device_id = GetAppropriateDeviceID({g_name});
               CreateReduceOp(&result, g_name, cur_device_id);
               var_name_on_devices_.emplace(g_name, cur_device_id);
               bcast_var_name_set[cur_device_id].emplace(p_name);
               break;
             case BuildStrategy::ReduceStrategy::kAllReduce:
-              if (IsSparseGradient(all_vars, g_name)) {
+              if (IsSparseGradient(g_name)) {
                 CreateReduceOp(&result, g_name, 0);
                 CreateBroadcastOp(&result, g_name, 0);
               } else {
@@ -310,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   return std::unique_ptr<SSAGraph>(graph);
 }
 
-bool MultiDevSSAGraphBuilder::IsSparseGradient(
-    const std::unordered_map<std::string, VarDesc *> &all_vars,
-    const std::string &og) const {
-  PADDLE_ENFORCE(all_vars.count(og) != 0);
-  if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
+bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
+  PADDLE_ENFORCE(all_vars_.count(og) != 0);
+  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
     return true;
   }
   return false;
@@ -498,18 +460,66 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
 }
 
 void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
-                                                const OpDesc &op,
-                                                int place_id) const {
-  CreateComputationalOp(result, op, place_id);
+                                                const OpDesc &op) const {
+  int op_dev_id = -1;
+  if (op.Type() == "split_byref") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
+      for (auto &varname : op.InputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
+    }
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
+  } else if (op.Type() == "concat") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+  } else {
+    PADDLE_ENFORCE(
+        "the distribute training related op should be in [split_byref, "
+        "concat].");
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1,
+                 "can not find right place for distributed op: %s", op.Type());
+
+  CreateComputationalOp(result, op, op_dev_id);
   if (op.Type() == "concat") {
     ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
-                                          int device_id) const {
-  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[device_id],
-                                            op.Type(), places_[device_id]));
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+                                          const OpDesc &op) const {
+  int op_dev_id = -1;
+  if (op.Type() == "send") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    // the variable name which contains .block means it was split by
+    // split_byref op
+    // so that we can balance the variable blocks to all the pserver instances.
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
+        op.InputArgumentNames()[0].find(".block") == std::string::npos) {
+      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
+      for (auto &varname : op.InputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
+    }
+  } else if (op.Type() == "recv") {
+    op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
+  } else {
+    // send_barrier and fetch_barrier op can be scheduled on device 0
+    op_dev_id = 0;
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
+                 op.Type());
+
+  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
+                                            op.Type(), places_[op_dev_id]));
 
   if (op.Type() == "send_barrier") {
     ConnectOp(result, result->ops_.back().get(), "send");
@@ -525,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
                  "send, send_barrier. recv, fetch_barrier]");
   }
 
-  // TODO(Yancey1989): schedule rpc op on different place may
-  // increate throughput
-  CreateOpHandleIOs(result, op, device_id);
+  CreateOpHandleIOs(result, op, op_dev_id);
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index a8a2045288..eb1c07630a 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -65,9 +65,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
   bool IsScaleLossOp(const OpDesc &op) const;
 
-  void CreateRPCOp(SSAGraph *result, const OpDesc &op, int place_id) const;
-  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op,
-                         int place_id) const;
+  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
 
   /**
    * Is this operator as the end-point operator before/after send operator.
@@ -105,13 +104,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
                          size_t src_dev_id) const;
 
-  bool IsSparseGradient(
-      const std::unordered_map<std::string, VarDesc *> &all_vars,
-      const std::string &og) const;
+  bool IsSparseGradient(const std::string &og) const;
+
+  size_t GetAppropriateDeviceID(
+      const std::vector<std::string> &var_names) const;
 
  private:
   BuildStrategy strategy_;
+  mutable std::unordered_map<std::string, VarDesc *> all_vars_;
   mutable std::unordered_map<std::string, int> var_name_on_devices_;
+  mutable std::vector<int64_t> balance_vars_;
 
   void SetCommunicationContext(OpHandleBase *op_handle,
                                const platform::Place &p) const;
-- 
GitLab
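For reference, the GetAppropriateDeviceID helper added above implements a greedy, size-based placement rule: sum the element counts of the variables to be placed, pick the device whose accumulated total (balance_vars_) is currently the smallest, and charge the new sum to that device. The patch reuses this single rule for kReduce gradient placement and for spreading the send/recv/split_byref work across devices. The standalone C++ sketch below illustrates the same idea with plain STL containers; PickDevice, device_load, and var_numel are illustrative stand-ins, not names from the patch or the PaddlePaddle API.

// Greedy, size-based placement sketch (illustrative only, not part of the patch).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

static std::vector<int64_t> device_load;                    // per-device accumulated numel
static std::unordered_map<std::string, int64_t> var_numel;  // stand-in for VarDesc shapes

// Pick the least-loaded device for a group of variables and record their size on it.
size_t PickDevice(const std::vector<std::string> &var_names) {
  int64_t numel_sum = 0;
  for (const auto &name : var_names) {
    // The real builder computes framework::product(var_desc->GetShape()) here.
    numel_sum += var_numel.at(name);
  }
  auto smallest = std::min_element(device_load.begin(), device_load.end());
  size_t dev_id = static_cast<size_t>(std::distance(device_load.begin(), smallest));
  device_load[dev_id] += numel_sum;
  return dev_id;
}

int main() {
  device_load.assign(4, 0);  // e.g. 4 devices, all idle
  var_numel = {{"w1", 1 << 20}, {"w2", 1 << 18}, {"w3", 1 << 20}};
  std::cout << "w1 -> dev " << PickDevice({"w1"}) << "\n";  // dev 0
  std::cout << "w2 -> dev " << PickDevice({"w2"}) << "\n";  // dev 1
  std::cout << "w3 -> dev " << PickDevice({"w3"}) << "\n";  // dev 2
  return 0;
}

Note that the real helper enforces numel > 0 for every variable it places, so it assumes each placed variable has a fully specified shape.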