diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 78356cb1be3bd089c26dde663275e2c8109df951..cf5968993fd4f31f8b6ff6ae482b1fe50310a00f 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( for (auto &p : params) { grad_names_.insert(GradVarName(p)); } + balance_vars_.resize(places_.size(), 0); } void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, @@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( checker(op.InputArgumentNames(), recv_vars); } +size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( + const std::vector &var_names) const { + int64_t numel_sum = 0; + for (auto var_name : var_names) { + auto var_desc = all_vars_.at(var_name); + PADDLE_ENFORCE_NOT_NULL(var_desc); + auto dim = framework::make_ddim(var_desc->GetShape()); + int64_t numel = framework::product(dim); + PADDLE_ENFORCE_GT(numel, 0); + numel_sum += numel; + } + + auto smallest = + std::min_element(std::begin(balance_vars_), std::end(balance_vars_)); + size_t dev_id = + static_cast(std::distance(std::begin(balance_vars_), smallest)); + balance_vars_[dev_id] += numel_sum; + return dev_id; +} + std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { - std::unordered_map all_vars; for (auto *var : program.Block(0).AllVars()) { - all_vars[var->Name()] = var; + all_vars_.emplace(var->Name(), var); } auto graph = new SSAGraph(); @@ -161,35 +181,16 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( auto send_vars = FindDistTrainSendVars(program); auto recv_vars = FindDistTrainRecvVars(program); - std::vector> var_name_on_devices; std::vector> bcast_var_name_set; - var_name_on_devices.resize(places_.size()); bcast_var_name_set.resize(places_.size()); size_t cur_device_id = 0; - std::vector 
balance_grads(places_.size(), 0); - - auto get_appropriate_dev = [&](std::string &g_name) -> size_t { - auto var_desc = all_vars.at(g_name); - PADDLE_ENFORCE_NOT_NULL(var_desc); - auto dim = framework::make_ddim(var_desc->GetShape()); - int64_t numel = framework::product(dim); - PADDLE_ENFORCE_GE(numel, 0); - auto smallest = - std::min_element(std::begin(balance_grads), std::end(balance_grads)); - size_t dev_id = - static_cast(std::distance(std::begin(balance_grads), smallest)); - balance_grads[dev_id] += numel; - return dev_id; - }; - bool is_forwarding = true; + for (auto *op : program.Block(0).AllOps()) { if (boost::get( op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kRPC)) { - // append rpc op if program is distributed trainer main program. - // always use the first device CreateRPCOp(&result, *op); } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { CreateDistTrainOp(&result, *op); @@ -201,13 +202,13 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } is_forwarding = false; } else { - int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); + int op_dev_id = GetOpDeviceID(*op); if (op_dev_id == -1) { // var on all device CreateComputationalOps(&result, *op, places_.size()); } else { CreateComputationalOp(&result, *op, op_dev_id); for (auto &var_name : op->OutputArgumentNames()) { - var_name_on_devices[op_dev_id].emplace(var_name); + var_name_on_devices_.emplace(var_name, op_dev_id); } } if (!is_forwarding && places_.size() > 1) { @@ -230,13 +231,13 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( switch (strategy_.reduce_) { case BuildStrategy::ReduceStrategy::kReduce: - cur_device_id = get_appropriate_dev(g_name); + cur_device_id = GetAppropriateDeviceID({g_name}); CreateReduceOp(&result, g_name, cur_device_id); - var_name_on_devices[cur_device_id].emplace(g_name); + var_name_on_devices_.emplace(g_name, cur_device_id); bcast_var_name_set[cur_device_id].emplace(p_name); break; case BuildStrategy::ReduceStrategy::kAllReduce: 
- if (IsSparseGradient(all_vars, g_name)) { + if (IsSparseGradient(g_name)) { CreateReduceOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0); } else { @@ -273,11 +274,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( return std::unique_ptr(graph); } -bool MultiDevSSAGraphBuilder::IsSparseGradient( - const std::unordered_map &all_vars, - const std::string &og) const { - PADDLE_ENFORCE(all_vars.count(og) != 0); - if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { +bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { + PADDLE_ENFORCE(all_vars_.count(og) != 0); + if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { return true; } return false; @@ -363,24 +362,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( return is_pg_once; } -int MultiDevSSAGraphBuilder::GetOpDeviceID( - const std::vector> &var_name_on_devices, - const OpDesc &op) const { +int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { return -1; } - int var_dev_id = -1; - for (auto &var_name : op.InputArgumentNames()) { - if (var_dev_id != -1) break; - for (size_t i = 0; i < var_name_on_devices.size(); ++i) { - if (var_name_on_devices[i].count(var_name)) { - var_dev_id = static_cast(i); - break; - } + for (auto &varname : op.InputArgumentNames()) { + int dev_id = GetVarDeviceID(varname); + if (dev_id != -1) { + return dev_id; } } - return var_dev_id; + return -1; +} + +int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { + auto got = var_name_on_devices_.find(varname); + return got == var_name_on_devices_.end() ? 
-1 : got->second; } void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { @@ -463,7 +461,30 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const { - CreateComputationalOp(result, op, 0); + int op_dev_id = -1; + if (op.Type() == "split_byref") { + op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { + op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); + for (auto &varname : op.InputArgumentNames()) { + var_name_on_devices_.emplace(varname, op_dev_id); + } + } + for (auto &varname : op.OutputArgumentNames()) { + var_name_on_devices_.emplace(varname, op_dev_id); + } + } else if (op.Type() == "concat") { + op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + } else { + PADDLE_ENFORCE( + "the distribute training related op should be in [split_byref, " + "concat]."); + } + + PADDLE_ENFORCE(op_dev_id != -1, + "can not find right place for distributed op: %s", op.Type()); + + CreateComputationalOp(result, op, op_dev_id); if (op.Type() == "concat") { ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); } @@ -471,8 +492,34 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op) const { - result->ops_.emplace_back( - new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0])); + int op_dev_id = -1; + if (op.Type() == "send") { + op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + // the variable name which contains .block means it was split by + // split_byref op + // so that we can balance the variable blocks to all the pserver instances.
+ if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && + op.InputArgumentNames()[0].find(".block") == std::string::npos) { + op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); + for (auto &varname : op.InputArgumentNames()) { + var_name_on_devices_.emplace(varname, op_dev_id); + } + } + } else if (op.Type() == "recv") { + op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames()); + for (auto &varname : op.OutputArgumentNames()) { + var_name_on_devices_.emplace(varname, op_dev_id); + } + } else { + // send_barrier and fetch_barrier op can be scheduled on device 0 + op_dev_id = 0; + } + + PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", + op.Type()); + + result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], + op.Type(), places_[op_dev_id])); if (op.Type() == "send_barrier") { ConnectOp(result, result->ops_.back().get(), "send"); @@ -488,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, "send, send_barrier. 
recv, fetch_barrier]"); } - // TODO(Yancey1989): schedule rpc op on different place may - // increate throughput - CreateOpHandleIOs(result, op, 0); + CreateOpHandleIOs(result, op, op_dev_id); } bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 78581755fe4890800636944d6cd89875a852cc19..eb1c07630ab665a90d76b810a421cffb0ce673c2 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { #endif std::unique_ptr Build(const ProgramDesc &program) const override; + int GetVarDeviceID(const std::string &varname) const; private: void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, - size_t place_id) const; + size_t device_id) const; private: std::string loss_var_name_; @@ -96,21 +97,23 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::string &og, std::unordered_set *og_has_been_broadcast) const; - int GetOpDeviceID( - const std::vector> &var_name_on_devices, - const OpDesc &op) const; + int GetOpDeviceID(const OpDesc &op) const; void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, size_t src_dev_id) const; - bool IsSparseGradient( - const std::unordered_map &all_vars, - const std::string &og) const; + bool IsSparseGradient(const std::string &og) const; + + size_t GetAppropriateDeviceID( + const std::vector &var_names) const; private: BuildStrategy strategy_; + mutable std::unordered_map all_vars_; + mutable std::unordered_map var_name_on_devices_; + mutable std::vector balance_vars_; void SetCommunicationContext(OpHandleBase *op_handle, const platform::Place &p) const; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h 
b/paddle/fluid/framework/details/ssa_graph_builder.h index 5fc12a44b51fae26e5a8f5fdba952d3879e82d0f..9eb23c46264f9036f009b0ae9aeeb34ec70c0e53 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -30,6 +30,7 @@ class SSAGraphBuilder { SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; + virtual int GetVarDeviceID(const std::string &var_name) const { return -1; } DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 9406c6155da860c90739bddac1e81403b094e619..d478865fa8f24c653a4185cabd05747a5410ceaa 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor( // Step 3. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp - details::SSAGraphBuilderFactory builder_factory( member_->places_, loss_var_name, params, member_->local_scopes_, build_strategy); @@ -122,9 +121,10 @@ ParallelExecutor::ParallelExecutor( #endif } + builder_ = std::move(builder_factory.Create()); member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, - builder_factory.Create()->Build(main_program))); + builder_->Build(main_program))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), @@ -133,10 +133,22 @@ ParallelExecutor::ParallelExecutor( void ParallelExecutor::BCastParamsToGPUs( const std::unordered_set &vars) const { - auto *main_scope = member_->local_scopes_[0]; + // for the initial bcast, all vars would be bcast from device(0), otherwise + // bcast from the specified device. + bool initialize = builder_.get() == nullptr ?
true : false; for (auto &var : vars) { - auto *main_var = main_scope->FindVar(var); + int var_dev_id = + builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var); + if (!initialize && var_dev_id == -1) continue; + + framework::Variable *main_var = nullptr; + if (initialize) { + main_var = member_->local_scopes_[0]->FindVar(var); + } else { + main_var = member_->local_scopes_[var_dev_id]->FindVar(var); + } + if (main_var == nullptr || !main_var->IsType()) { continue; } @@ -151,7 +163,8 @@ void ParallelExecutor::BCastParamsToGPUs( for (size_t i = 0; i < member_->places_.size(); ++i) { auto place = member_->places_[i]; void *buffer; - if (i == 0) { + + if ((initialize && i == 0) || (!initialize && i == var_dev_id)) { buffer = const_cast(main_tensor.data()); } else { auto local_scope = member_->local_scopes_[i]; diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 5247e790649e76567f4527d54499d6e95dac5c27..058f83f07c26224e3180d140630c08a24c40cd80 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -19,12 +19,14 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/details/execution_strategy.h" +#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" + namespace paddle { namespace framework { @@ -68,6 +70,7 @@ class ParallelExecutor { private: ParallelExecutorPrivate *member_; + std::unique_ptr builder_; }; } // namespace framework