diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 600705bb81f00f929f56ceba408909e84431034a..cc07cfc5527aee73245e566c624f72f679926701 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -142,7 +142,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
-  VLOG(3) << "Building ....";
   std::unordered_map<std::string, VarDesc *> all_vars;
   for (auto *var : program.Block(0).AllVars()) {
     all_vars[var->Name()] = var;
   }
@@ -162,36 +161,32 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   auto send_vars = FindDistTrainSendVars(program);
   auto recv_vars = FindDistTrainRecvVars(program);
 
-  std::vector<std::unordered_set<std::string>> var_name_on_devices;
   std::vector<std::unordered_set<std::string>> bcast_var_name_set;
-  var_name_on_devices.resize(places_.size());
   bcast_var_name_set.resize(places_.size());
 
   size_t cur_device_id = 0;
   std::vector<int64_t> balance_grads(places_.size(), 0);
 
-  auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
-    auto var_desc = all_vars.at(g_name);
-    PADDLE_ENFORCE_NOT_NULL(var_desc);
-    auto dim = framework::make_ddim(var_desc->GetShape());
-    int64_t numel = framework::product(dim);
-    PADDLE_ENFORCE_GE(numel, 0);
+  auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
+    int64_t numel_all = 0;
+    for (auto var_name : var_names) {
+      auto var_desc = all_vars.at(var_name);
+      PADDLE_ENFORCE_NOT_NULL(var_desc);
+      auto dim = framework::make_ddim(var_desc->GetShape());
+      int64_t numel = framework::product(dim);
+      PADDLE_ENFORCE_GT(numel, 0);
+      numel_all += numel;
+    }
+
     auto smallest =
         std::min_element(std::begin(balance_grads), std::end(balance_grads));
     size_t dev_id =
         static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
-    balance_grads[dev_id] += numel;
+    balance_grads[dev_id] += numel_all;
     return dev_id;
   };
 
   bool is_forwarding = true;
-  int rpc_op_device_id = 0;
-  auto schedule_rpc_op = [&]() -> void {
-    rpc_op_device_id++;
-    if (rpc_op_device_id >= static_cast<int>(places_.size())) {
-      rpc_op_device_id = 0;
-    }
-  };
 
   for (auto *op : program.Block(0).AllOps()) {
     if (boost::get<int>(
@@ -200,37 +195,40 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       // append rpc op if program is distributed trainer main program.
       // always use the first device
       if (op->Type() == "send_vars") {
-        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
-        if (got == remote_vars_devices_.end()) {
-          schedule_rpc_op();
-        } else {
-          rpc_op_device_id = got->second;
+        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
+        if (op_dev_id == -1) {
+          op_dev_id = get_appropriate_dev(op->InputArgumentNames());
+          for (auto &varname : op->InputArgumentNames()) {
+            var_name_on_devices_.emplace(varname, op_dev_id);
+          }
         }
-        CreateRPCOp(&result, *op, rpc_op_device_id);
+        CreateRPCOp(&result, *op, op_dev_id);
       } else if (op->Type() == "recv") {
-        schedule_rpc_op();
+        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
         for (auto &varname : op->OutputArgumentNames()) {
-          remote_vars_devices_.insert({varname, rpc_op_device_id});
+          var_name_on_devices_.emplace(varname, op_dev_id);
        }
-        CreateRPCOp(&result, *op, rpc_op_device_id);
+        CreateRPCOp(&result, *op, op_dev_id);
       } else {
+        // send_barrier and fetch_barrier ops always run on device 0
         CreateRPCOp(&result, *op, 0);
       }
     } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
       if (op->Type() == "split_byref") {
-        schedule_rpc_op();
+        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
         for (auto &varname : op->OutputArgumentNames()) {
-          remote_vars_devices_.insert({varname, rpc_op_device_id});
+          var_name_on_devices_.emplace(varname, op_dev_id);
         }
-        CreateDistTrainOp(&result, *op, rpc_op_device_id);
-      }
-      if (op->Type() == "concat") {
-        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
-        PADDLE_ENFORCE(got != remote_vars_devices_.end(),
+        CreateDistTrainOp(&result, *op, op_dev_id);
+      } else if (op->Type() == "concat") {
+        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
+        PADDLE_ENFORCE(op_dev_id != -1,
                        "can not find right place to concatenate received var.");
-        CreateDistTrainOp(&result, *op, got->second);
+        CreateDistTrainOp(&result, *op, op_dev_id);
       } else {
-        CreateDistTrainOp(&result, *op, 0);
+        PADDLE_ENFORCE(false,
+                       "the distributed training related op should be one of "
+                       "[split_byref, concat].");
       }
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
@@ -240,13 +238,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       }
       is_forwarding = false;
     } else {
-      int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
+      int op_dev_id = GetOpDeviceID(*op);
       if (op_dev_id == -1) {  // var on all device
         CreateComputationalOps(&result, *op, places_.size());
       } else {
         CreateComputationalOp(&result, *op, op_dev_id);
         for (auto &var_name : op->OutputArgumentNames()) {
-          var_name_on_devices[op_dev_id].emplace(var_name);
+          var_name_on_devices_.emplace(var_name, op_dev_id);
         }
       }
       if (!is_forwarding && places_.size() > 1) {
@@ -269,9 +267,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
           switch (strategy_.reduce_) {
             case BuildStrategy::ReduceStrategy::kReduce:
-              cur_device_id = get_appropriate_dev(g_name);
+              cur_device_id = get_appropriate_dev({g_name});
               CreateReduceOp(&result, g_name, cur_device_id);
-              var_name_on_devices[cur_device_id].emplace(g_name);
+              var_name_on_devices_.emplace(g_name, cur_device_id);
               bcast_var_name_set[cur_device_id].emplace(p_name);
               break;
             case BuildStrategy::ReduceStrategy::kAllReduce:
@@ -402,24 +400,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
   return is_pg_once;
 }
 
-int MultiDevSSAGraphBuilder::GetOpDeviceID(
-    const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
-    const OpDesc &op) const {
+int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
   if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
 
-  int var_dev_id = -1;
-  for (auto &var_name : op.InputArgumentNames()) {
-    if (var_dev_id != -1) break;
-    for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
-      if (var_name_on_devices[i].count(var_name)) {
-        var_dev_id = static_cast<int>(i);
-        break;
-      }
+  for (auto &varname : op.InputArgumentNames()) {
+    int dev_id = GetVarDeviceID(varname);
+    if (dev_id != -1) {
+      return dev_id;
     }
   }
-  return var_dev_id;
+  return -1;
+}
+
+int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
+  auto got = var_name_on_devices_.find(varname);
+  return got == var_name_on_devices_.end() ? -1 : got->second;
 }
 
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index b89686edd94680a11ddf06b64fa8828a07a514f8..a8a2045288f2c40a853776af6d33f5e64b5c9a1b 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -47,14 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
-
-  int GetRemoteVarDeviceId(const std::string &var_name) const override {
-    auto got = remote_vars_devices_.find(var_name);
-    if (got != remote_vars_devices_.end()) {
-      return got->second;
-    }
-    return -1;
-  }
+  int GetVarDeviceID(const std::string &varname) const;
 
  private:
   void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
@@ -105,9 +98,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
       const std::string &og,
       std::unordered_set<std::string> *og_has_been_broadcast) const;
 
-  int GetOpDeviceID(
-      const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
-      const OpDesc &op) const;
+  int GetOpDeviceID(const OpDesc &op) const;
 
   void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
 
@@ -120,7 +111,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
  private:
   BuildStrategy strategy_;
-  mutable std::unordered_map<std::string, int> remote_vars_devices_;
+  mutable std::unordered_map<std::string, int> var_name_on_devices_;
 
   void SetCommunicationContext(OpHandleBase *op_handle,
                                const platform::Place &p) const;
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h
index 3c2c5273689267eb94b9727822cd12b8bc4fd124..9eb23c46264f9036f009b0ae9aeeb34ec70c0e53 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.h
+++ b/paddle/fluid/framework/details/ssa_graph_builder.h
@@ -30,9 +30,7 @@ class SSAGraphBuilder {
   SSAGraphBuilder() {}
   virtual ~SSAGraphBuilder() {}
   virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
-  virtual int GetRemoteVarDeviceId(const std::string &var_name) const {
-    return -1;
-  }
+  virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
 
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
 
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index e8c9ef29cad281e50861cd1daf12e2c793a7f711..5f4359357f8279cf1c537a0914e5395633407c7b 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -161,9 +161,8 @@ void ParallelExecutor::BCastParamsToGPUs(
         }
 
         auto &nccl_ctx = member_->nccl_ctxs_->at(place);
-        if (builder_.get() != nullptr &&
-            builder_->GetRemoteVarDeviceId(var) != -1) {
-          int place_id = builder_->GetRemoteVarDeviceId(var);
+        if (builder_.get() != nullptr && builder_->GetVarDeviceID(var) != -1) {
+          int place_id = builder_->GetVarDeviceID(var);
           platform::dynload::ncclBcast(buffer, numel, data_type, place_id,
                                        nccl_ctx.comm_, nccl_ctx.stream());
         } else {
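
Note for reviewers: the heart of this patch is that get_appropriate_dev now takes a whole group of variable names, sums their element counts, and places the group on the device whose accumulated load is currently smallest; var_name_on_devices_ then remembers that choice so GetVarDeviceID/GetOpDeviceID can route later consumers (concat, downstream computational ops under kReduce, the NCCL broadcast in parallel_executor.cc) to the same device. The standalone sketch below only illustrates that greedy balancing idea; it is not code from the patch, and PickDevice, main, and the sample sizes are hypothetical.

// balance_demo.cc -- illustrative only, not part of the patch.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Greedy balancer in the spirit of get_appropriate_dev: place each group of
// variables on the device that currently holds the fewest elements, then
// charge that device for the group's total element count.
size_t PickDevice(const std::vector<int64_t> &group_numels,
                  std::vector<int64_t> *balance) {
  int64_t numel_all =
      std::accumulate(group_numels.begin(), group_numels.end(), int64_t{0});
  auto smallest = std::min_element(balance->begin(), balance->end());
  size_t dev_id =
      static_cast<size_t>(std::distance(balance->begin(), smallest));
  (*balance)[dev_id] += numel_all;
  return dev_id;
}

int main() {
  std::vector<int64_t> balance(4, 0);  // four devices, all initially empty
  // Element counts of successive gradient groups (made-up sizes).
  std::vector<std::vector<int64_t>> groups = {{1000}, {200, 300}, {50}, {4000}};
  for (const auto &g : groups) {
    std::cout << "group assigned to device " << PickDevice(g, &balance) << "\n";
  }
  return 0;
}

Under BuildStrategy::ReduceStrategy::kReduce this is what keeps a reduced gradient, the split/received pieces of a remote variable, and the ops that consume them co-located on a single device.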